tensogram-xarray 0.14.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tensogram_xarray-0.14.0/.gitignore +26 -0
- tensogram_xarray-0.14.0/PKG-INFO +23 -0
- tensogram_xarray-0.14.0/pyproject.toml +65 -0
- tensogram_xarray-0.14.0/src/tensogram_xarray/__init__.py +24 -0
- tensogram_xarray-0.14.0/src/tensogram_xarray/array.py +408 -0
- tensogram_xarray-0.14.0/src/tensogram_xarray/backend.py +139 -0
- tensogram_xarray-0.14.0/src/tensogram_xarray/coords.py +113 -0
- tensogram_xarray-0.14.0/src/tensogram_xarray/mapping.py +91 -0
- tensogram_xarray-0.14.0/src/tensogram_xarray/merge.py +832 -0
- tensogram_xarray-0.14.0/src/tensogram_xarray/scanner.py +196 -0
- tensogram_xarray-0.14.0/src/tensogram_xarray/store.py +383 -0
- tensogram_xarray-0.14.0/tests/__init__.py +8 -0
- tensogram_xarray-0.14.0/tests/conftest.py +203 -0
- tensogram_xarray-0.14.0/tests/test_array.py +149 -0
- tensogram_xarray-0.14.0/tests/test_backend.py +133 -0
- tensogram_xarray-0.14.0/tests/test_coords.py +92 -0
- tensogram_xarray-0.14.0/tests/test_coverage.py +1550 -0
- tensogram_xarray-0.14.0/tests/test_edge_cases.py +487 -0
- tensogram_xarray-0.14.0/tests/test_mapping.py +70 -0
- tensogram_xarray-0.14.0/tests/test_merge.py +190 -0
- tensogram_xarray-0.14.0/tests/test_nd_range.py +213 -0
- tensogram_xarray-0.14.0/tests/test_remote.py +136 -0
|
@@ -0,0 +1,26 @@
|
|
|
1
|
+
.claude/*
|
|
2
|
+
!.claude/commands/
|
|
3
|
+
.weave/
|
|
4
|
+
.sisyphus/
|
|
5
|
+
.coverage
|
|
6
|
+
*.dylib.dSYM/
|
|
7
|
+
**/target
|
|
8
|
+
**/pkg
|
|
9
|
+
**/build/
|
|
10
|
+
python/bindings/Cargo.lock
|
|
11
|
+
|
|
12
|
+
/docs/book
|
|
13
|
+
**/.venv
|
|
14
|
+
**/.ruff_cache
|
|
15
|
+
**/__pycache__
|
|
16
|
+
*.so
|
|
17
|
+
*.dylib
|
|
18
|
+
*.pyd
|
|
19
|
+
*.swp
|
|
20
|
+
*.swo
|
|
21
|
+
*~
|
|
22
|
+
.DS_Store
|
|
23
|
+
.idea/
|
|
24
|
+
rust/tensogram-grib/Cargo.lock
|
|
25
|
+
rust/tensogram-netcdf/Cargo.lock
|
|
26
|
+
rust/tensogram-wasm/Cargo.lock
|
|
@@ -0,0 +1,23 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: tensogram-xarray
|
|
3
|
+
Version: 0.14.0
|
|
4
|
+
Summary: xarray backend engine for tensogram .tgm files
|
|
5
|
+
Project-URL: Homepage, https://sites.ecmwf.int/docs/tensogram/main
|
|
6
|
+
Project-URL: Repository, https://github.com/ecmwf/tensogram
|
|
7
|
+
Project-URL: Documentation, https://sites.ecmwf.int/docs/tensogram/main
|
|
8
|
+
Author-email: ECMWF <software@ecmwf.int>
|
|
9
|
+
License-Expression: Apache-2.0
|
|
10
|
+
Classifier: Development Status :: 4 - Beta
|
|
11
|
+
Classifier: License :: OSI Approved :: Apache Software License
|
|
12
|
+
Classifier: Programming Language :: Python :: 3
|
|
13
|
+
Classifier: Topic :: Scientific/Engineering
|
|
14
|
+
Classifier: Topic :: Scientific/Engineering :: Atmospheric Science
|
|
15
|
+
Requires-Python: >=3.9
|
|
16
|
+
Requires-Dist: numpy
|
|
17
|
+
Requires-Dist: tensogram<0.15,>=0.14.0
|
|
18
|
+
Requires-Dist: xarray>=2022.06
|
|
19
|
+
Provides-Extra: dask
|
|
20
|
+
Requires-Dist: dask[array]; extra == 'dask'
|
|
21
|
+
Provides-Extra: dev
|
|
22
|
+
Requires-Dist: pytest>=7.0; extra == 'dev'
|
|
23
|
+
Requires-Dist: ruff>=0.4; extra == 'dev'
|
|
@@ -0,0 +1,65 @@
|
|
|
1
|
+
[build-system]
|
|
2
|
+
requires = ["hatchling"]
|
|
3
|
+
build-backend = "hatchling.build"
|
|
4
|
+
|
|
5
|
+
[project]
|
|
6
|
+
name = "tensogram-xarray"
|
|
7
|
+
version = "0.14.0"
|
|
8
|
+
description = "xarray backend engine for tensogram .tgm files"
|
|
9
|
+
requires-python = ">=3.9"
|
|
10
|
+
license = "Apache-2.0"
|
|
11
|
+
authors = [{name = "ECMWF", email = "software@ecmwf.int"}]
|
|
12
|
+
classifiers = [
|
|
13
|
+
"Development Status :: 4 - Beta",
|
|
14
|
+
"License :: OSI Approved :: Apache Software License",
|
|
15
|
+
"Programming Language :: Python :: 3",
|
|
16
|
+
"Topic :: Scientific/Engineering",
|
|
17
|
+
"Topic :: Scientific/Engineering :: Atmospheric Science",
|
|
18
|
+
]
|
|
19
|
+
dependencies = [
|
|
20
|
+
"tensogram>=0.14.0,<0.15",
|
|
21
|
+
"xarray>=2022.06",
|
|
22
|
+
"numpy",
|
|
23
|
+
]
|
|
24
|
+
|
|
25
|
+
[project.urls]
|
|
26
|
+
Homepage = "https://sites.ecmwf.int/docs/tensogram/main"
|
|
27
|
+
Repository = "https://github.com/ecmwf/tensogram"
|
|
28
|
+
Documentation = "https://sites.ecmwf.int/docs/tensogram/main"
|
|
29
|
+
|
|
30
|
+
[project.optional-dependencies]
|
|
31
|
+
dev = ["pytest>=7.0", "ruff>=0.4"]
|
|
32
|
+
dask = ["dask[array]"]
|
|
33
|
+
|
|
34
|
+
[project.entry-points."xarray.backends"]
|
|
35
|
+
tensogram = "tensogram_xarray.backend:TensogramBackendEntrypoint"
|
|
36
|
+
|
|
37
|
+
[tool.hatch.build.targets.wheel]
|
|
38
|
+
packages = ["src/tensogram_xarray"]
|
|
39
|
+
|
|
40
|
+
[tool.ruff]
|
|
41
|
+
line-length = 99
|
|
42
|
+
target-version = "py39"
|
|
43
|
+
|
|
44
|
+
[tool.ruff.lint]
|
|
45
|
+
select = [
|
|
46
|
+
"E", # pycodestyle errors
|
|
47
|
+
"W", # pycodestyle warnings
|
|
48
|
+
"F", # pyflakes
|
|
49
|
+
"I", # isort
|
|
50
|
+
"N", # pep8-naming
|
|
51
|
+
"UP", # pyupgrade
|
|
52
|
+
"B", # flake8-bugbear
|
|
53
|
+
"SIM", # flake8-simplify
|
|
54
|
+
"PT", # flake8-pytest-style
|
|
55
|
+
"RUF", # ruff-specific rules
|
|
56
|
+
]
|
|
57
|
+
|
|
58
|
+
[tool.ruff.lint.per-file-ignores]
|
|
59
|
+
"tests/**" = ["RUF012"] # mutable class attrs in test mock classes are fine
|
|
60
|
+
|
|
61
|
+
[tool.ruff.lint.isort]
|
|
62
|
+
known-third-party = ["numpy", "xarray", "tensogram", "pytest"]
|
|
63
|
+
|
|
64
|
+
[tool.pytest.ini_options]
|
|
65
|
+
testpaths = ["tests"]
|
|
@@ -0,0 +1,24 @@
|
|
|
1
|
+
# (C) Copyright 2026- ECMWF and individual contributors.
|
|
2
|
+
#
|
|
3
|
+
# This software is licensed under the terms of the Apache Licence Version 2.0
|
|
4
|
+
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
|
|
5
|
+
# In applying this licence, ECMWF does not waive the privileges and immunities
|
|
6
|
+
# granted to it by virtue of its status as an intergovernmental organisation nor
|
|
7
|
+
# does it submit to any jurisdiction.
|
|
8
|
+
|
|
9
|
+
"""tensogram-xarray: xarray backend engine for tensogram .tgm files.
|
|
10
|
+
|
|
11
|
+
Provides ``engine="tensogram"`` for ``xr.open_dataset()`` and a top-level
|
|
12
|
+
``open_datasets()`` function for multi-message .tgm files that auto-groups
|
|
13
|
+
incompatible objects into separate Datasets.
|
|
14
|
+
"""
|
|
15
|
+
|
|
16
|
+
from __future__ import annotations
|
|
17
|
+
|
|
18
|
+
from tensogram_xarray.backend import TensogramBackendEntrypoint
|
|
19
|
+
from tensogram_xarray.merge import open_datasets
|
|
20
|
+
|
|
21
|
+
# Public names re-exported at package level: the xarray backend entry point
# class and the multi-message ``open_datasets()`` helper imported above.
__all__ = [
    "TensogramBackendEntrypoint",
    "open_datasets",
]
|
|
@@ -0,0 +1,408 @@
|
|
|
1
|
+
# (C) Copyright 2026- ECMWF and individual contributors.
|
|
2
|
+
#
|
|
3
|
+
# This software is licensed under the terms of the Apache Licence Version 2.0
|
|
4
|
+
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
|
|
5
|
+
# In applying this licence, ECMWF does not waive the privileges and immunities
|
|
6
|
+
# granted to it by virtue of its status as an intergovernmental organisation nor
|
|
7
|
+
# does it submit to any jurisdiction.
|
|
8
|
+
|
|
9
|
+
"""Lazy-loading backend array for tensogram data objects.
|
|
10
|
+
|
|
11
|
+
``TensogramBackendArray`` implements :class:`xarray.backends.BackendArray` so
|
|
12
|
+
that tensor payloads are decoded on demand. For compressors that support
|
|
13
|
+
random access (``none``, ``szip``, ``blosc2``, ``zfp`` fixed-rate) and have no
|
|
14
|
+
``shuffle`` filter, N-dimensional slice requests are mapped to flat element
|
|
15
|
+
ranges and decoded via ``tensogram.decode_range()``. Otherwise the full
|
|
16
|
+
object is decoded and sliced in-memory via ``tensogram.decode_object()``.
|
|
17
|
+
|
|
18
|
+
A ratio-based heuristic controls when partial reads are used: if the
|
|
19
|
+
fraction of requested elements exceeds ``range_threshold`` (default 0.5),
|
|
20
|
+
the backend falls back to a full decode.
|
|
21
|
+
"""
|
|
22
|
+
|
|
23
|
+
from __future__ import annotations
|
|
24
|
+
|
|
25
|
+
import logging
|
|
26
|
+
import math
|
|
27
|
+
import os
|
|
28
|
+
import threading
|
|
29
|
+
from itertools import product as iterproduct
|
|
30
|
+
from typing import Any
|
|
31
|
+
|
|
32
|
+
import numpy as np
|
|
33
|
+
from xarray.backends import BackendArray
|
|
34
|
+
from xarray.core import indexing
|
|
35
|
+
|
|
36
|
+
# Module logger; used to report decode_range fallbacks at debug level.
logger = logging.getLogger(__name__)

# Compressor values that support partial decode via decode_range().
# NOTE: "zfp" is listed here but is only accepted by _supports_range_decode
# when its params request "fixed_rate" mode.
_RANDOM_ACCESS_COMPRESSORS = frozenset({"none", "szip", "blosc2", "zfp"})

# Filters that break contiguous byte ranges (shuffle rearranges bytes).
_RANGE_BLOCKING_FILTERS = frozenset({"shuffle"})

# Default ratio threshold: use decode_range when the requested fraction of
# total elements is at or below this value.
DEFAULT_RANGE_THRESHOLD = 0.5
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
def _supports_range_decode(descriptor: Any) -> bool:
    """Decide whether ``decode_range()`` may be used for an object.

    Partial decoding is allowed only when the object's filter is not a
    range-blocking one (such as ``shuffle``) and its compressor is one of the
    random-access codecs. ``zfp`` is further restricted to descriptors whose
    ``params["zfp_mode"]`` is ``"fixed_rate"``.

    Missing attributes on *descriptor* fall back to ``"none"`` for
    ``compression`` and ``filter``, and to an empty mapping for ``params``.
    """
    codec = getattr(descriptor, "compression", "none")
    pre_filter = getattr(descriptor, "filter", "none")

    if pre_filter in _RANGE_BLOCKING_FILTERS or codec not in _RANDOM_ACCESS_COMPRESSORS:
        return False

    if codec != "zfp":
        return True

    # zfp streams are only randomly addressable in fixed-rate mode.
    zfp_params = getattr(descriptor, "params", {}) or {}
    return zfp_params.get("zfp_mode") == "fixed_rate"
|
|
67
|
+
|
|
68
|
+
|
|
69
|
+
def _is_contiguous_slice(key: tuple) -> bool:
    """Report whether every component of *key* is a unit-stride slice.

    A component qualifies when it is a ``slice`` whose step is ``None`` or
    ``1``. Any integer or other non-slice component, or a slice with any
    other step (including negative steps), disqualifies the whole key. An
    empty key is trivially accepted.
    """
    return all(
        isinstance(part, slice) and (part.step is None or part.step == 1) for part in key
    )
|
|
78
|
+
|
|
79
|
+
|
|
80
|
+
# ---------------------------------------------------------------------------
|
|
81
|
+
# N-D slice -> flat element ranges
|
|
82
|
+
# ---------------------------------------------------------------------------
|
|
83
|
+
|
|
84
|
+
|
|
85
|
+
def _nd_slice_to_flat_ranges(
    shape: tuple[int, ...],
    key: tuple[slice, ...],
) -> tuple[list[tuple[int, int]], tuple[int, ...]]:
    """Map an N-dimensional slice to flat ``(start, count)`` ranges.

    For a C-contiguous (row-major) array the elements of a hyper-rectangular
    slice are **not** contiguous in general. Contiguous runs exist only
    along the innermost (rightmost) axis. This function decomposes the
    N-D slice into the minimal set of flat ranges that cover exactly the
    requested elements, then merges adjacent ranges.

    Empty selections -- including "inverted" slices such as ``slice(5, 2)``,
    which numpy treats as selecting nothing -- produce no ranges and a zero
    extent in the corresponding output dimension.

    Parameters
    ----------
    shape
        Shape of the full tensor.
    key
        Tuple of ``slice`` objects (one per dimension, unit-stride only; the
        step of each slice is ignored).

    Returns
    -------
    flat_ranges
        List of ``(element_offset, element_count)`` in the flattened array,
        in ascending offset order.
    output_shape
        Shape of the result after slicing.
    """
    ndim = len(shape)

    # Parse each slice into (start, count).
    dim_ranges: list[tuple[int, int]] = []
    output_dims: list[int] = []
    for slc, d in zip(key, shape):
        s, e, _ = slc.indices(d)
        # slice.indices() does not clamp stop to start, so an inverted slice
        # such as slice(5, 2) would otherwise yield a negative count.
        count = max(0, e - s)
        dim_ranges.append((s, count))
        output_dims.append(count)
    output_shape = tuple(output_dims)

    if math.prod(output_dims) == 0:
        return [], output_shape

    # Compute C-contiguous strides (in elements).
    strides = [1] * ndim
    for i in range(ndim - 2, -1, -1):
        strides[i] = strides[i + 1] * shape[i + 1]

    # Find the *split point*: the innermost dimension whose slice is NOT a
    # full slice. All dimensions after it are full slices, so their elements
    # form a contiguous block. -1 means every dimension is full.
    split = -1
    for i in range(ndim - 1, -1, -1):
        start_i, count_i = dim_ranges[i]
        if start_i != 0 or count_i != shape[i]:
            split = i
            break

    if split == -1:
        # Every dimension is a full slice -- one range covering everything.
        return [(0, math.prod(shape))], output_shape

    # Contiguous block size: count at split dim * product of trailing dims.
    block_size = dim_ranges[split][1] * math.prod(shape[split + 1 :])
    block_offset = dim_ranges[split][0] * strides[split]

    if split == 0:
        # No outer dimensions to iterate.
        return [(block_offset, block_size)], output_shape

    # One range per combination of outer-dimension indices; itertools.product
    # enumerates in row-major order, so offsets come out ascending.
    outer_index_ranges = [range(start, start + count) for start, count in dim_ranges[:split]]

    flat_ranges: list[tuple[int, int]] = []
    for idx in iterproduct(*outer_index_ranges):
        base = sum(i * st for i, st in zip(idx, strides))
        flat_ranges.append((base + block_offset, block_size))

    # Merge adjacent ranges (consecutive with no gap). Kept as a defensive
    # step; with a non-full split dimension consecutive blocks are separated.
    if len(flat_ranges) > 1:
        merged: list[tuple[int, int]] = [flat_ranges[0]]
        for start, count in flat_ranges[1:]:
            prev_start, prev_count = merged[-1]
            if start == prev_start + prev_count:
                merged[-1] = (prev_start, prev_count + count)
            else:
                merged.append((start, count))
        flat_ranges = merged

    return flat_ranges, output_shape
|
|
179
|
+
|
|
180
|
+
|
|
181
|
+
# ---------------------------------------------------------------------------
|
|
182
|
+
# Backend array
|
|
183
|
+
# ---------------------------------------------------------------------------
|
|
184
|
+
|
|
185
|
+
|
|
186
|
+
class TensogramBackendArray(BackendArray):
    """Lazy array backed by a tensogram file.

    Stores the file path (or remote URL) and optionally a shared file handle.
    The handle is dropped on pickle for dask multiprocessing compatibility
    and lazily reopened on the worker.

    Parameters
    ----------
    file_path
        Local path or remote URL (as recognised by ``tensogram.is_remote_url``).
        Local paths are stored as absolute paths (``os.path.abspath``).
    msg_index
        Index of the message within the file.
    obj_index
        Index of the data object within that message.
    shape, dtype
        Shape and dtype exposed to xarray.
    supports_range
        Whether the object's pipeline allows partial ``decode_range()`` reads.
    verify_hash
        Forwarded to the tensogram decode calls.
    range_threshold
        Maximum requested/total element ratio for which a partial
        ``file_decode_range`` read is attempted instead of a full decode.
    lock
        Accepted for API compatibility; it is currently neither stored nor
        used, so reads through *shared_file* are not serialised here.
    storage_options
        Forwarded to ``TensogramFile.open_remote`` for remote URLs.
    shared_file
        Optional already-open file handle reused for reads. This class never
        closes it.
    """

    def __init__(
        self,
        file_path: str,
        msg_index: int,
        obj_index: int,
        shape: tuple[int, ...],
        dtype: np.dtype,
        supports_range: bool,
        *,
        verify_hash: bool = False,
        range_threshold: float = DEFAULT_RANGE_THRESHOLD,
        lock: threading.Lock | None = None,
        storage_options: dict[str, Any] | None = None,
        shared_file: Any | None = None,
    ):
        import tensogram

        self._is_remote = tensogram.is_remote_url(file_path)
        self.file_path = file_path if self._is_remote else os.path.abspath(file_path)
        self.msg_index = msg_index
        self.obj_index = obj_index
        self.shape = shape
        self.dtype = dtype
        self.supports_range = supports_range
        self.verify_hash = verify_hash
        self.range_threshold = range_threshold
        self.storage_options = storage_options
        self._shared_file = shared_file
        # NOTE(review): ``lock`` is intentionally not stored (see class docs);
        # storing a threading.Lock would also make instances unpicklable.

    # -- pickle support (no open handles stored) ----------------------------

    def __getstate__(self) -> dict:
        """Return picklable state with the shared file handle removed."""
        state = self.__dict__.copy()
        state["_shared_file"] = None
        return state

    def __setstate__(self, state: dict) -> None:
        """Restore state; subsequent reads open their own file handle."""
        self.__dict__.update(state)
        self._shared_file = None

    # -- BackendArray interface ---------------------------------------------

    def __getitem__(self, key: indexing.ExplicitIndexer) -> np.ndarray:
        return indexing.explicit_indexing_adapter(
            key,
            self.shape,
            indexing.IndexingSupport.BASIC,
            self._raw_indexing_method,
        )

    def _get_file(self):
        """Return the shared handle, or open a new ``TensogramFile``.

        A newly opened handle belongs to the caller, which must close it.
        """
        if self._shared_file is not None:
            return self._shared_file
        import tensogram

        if self._is_remote:
            return tensogram.TensogramFile.open_remote(self.file_path, self.storage_options or {})
        return tensogram.TensogramFile.open(self.file_path)

    def _raw_indexing_method(self, key: tuple) -> np.ndarray:
        """Return ``data[key]`` for a basic (integer/slice) *key*.

        Empty selections are answered without opening or decoding anything.
        Otherwise the shared handle is used when present; if not, a handle is
        opened for this call and closed afterwards.
        """
        # Let numpy compute the result shape of *key* on a zero-strided view
        # (no data-sized allocation), so selecting zero elements never pays for
        # a file open (remote included) or a full object decode.
        probe = np.broadcast_to(np.empty((), dtype=self.dtype), self.shape)[key]
        if probe.size == 0:
            return np.empty(probe.shape, dtype=self.dtype)

        import tensogram

        if self._shared_file is not None:
            return self._read_from_file(self._shared_file, key, tensogram)

        with self._get_file() as f:
            return self._read_from_file(f, key, tensogram)

    def _read_from_file(self, f, key: tuple, tensogram) -> np.ndarray:
        """Read ``key`` of object ``obj_index`` in message ``msg_index`` from *f*.

        A partial ``file_decode_range`` read is attempted when the pipeline
        supports it, *key* consists only of unit-stride slices, and the
        requested fraction of elements is at most ``range_threshold``. A
        ``ValueError``/``RuntimeError``/``OSError`` on that path is logged at
        debug level and the full-object decode is used instead.
        """
        if self.supports_range and _is_contiguous_slice(key):
            try:
                flat_ranges, out_shape = _nd_slice_to_flat_ranges(self.shape, key)
                total_requested = sum(c for _, c in flat_ranges)
                total_elements = math.prod(self.shape)

                if total_elements > 0 and total_requested / total_elements <= self.range_threshold:
                    arr = f.file_decode_range(
                        self.msg_index,
                        obj_index=self.obj_index,
                        ranges=flat_ranges,
                        join=True,
                        verify_hash=self.verify_hash,
                        native_byte_order=True,
                    )
                    return np.asarray(arr).reshape(out_shape)
            except (ValueError, RuntimeError, OSError) as exc:
                logger.debug(
                    "decode_range failed for %s msg=%d obj=%d, falling back to full decode: %s",
                    self.file_path,
                    self.msg_index,
                    self.obj_index,
                    exc,
                )

        if self._is_remote:
            result = f.file_decode_object(
                self.msg_index,
                self.obj_index,
                verify_hash=self.verify_hash,
            )
            return np.asarray(result["data"][key])

        raw_msg = f.read_message(self.msg_index)
        _meta, _desc, arr = tensogram.decode_object(
            raw_msg,
            self.obj_index,
            verify_hash=self.verify_hash,
        )
        return np.asarray(arr[key])
|
|
303
|
+
|
|
304
|
+
|
|
305
|
+
# ---------------------------------------------------------------------------
|
|
306
|
+
# Stacked backend array (lazy hypercube)
|
|
307
|
+
# ---------------------------------------------------------------------------
|
|
308
|
+
|
|
309
|
+
|
|
310
|
+
class StackedBackendArray(BackendArray):
    """Lazy stacked array composed of multiple :class:`TensogramBackendArray`.

    Each position along the outer dimensions maps to a separate backing
    array. Indexing dispatches to the correct backing array(s) and
    assembles the result, so no data is decoded until actually accessed.

    Parameters
    ----------
    arrays
        Backing arrays in row-major (C) order over *outer_shape*; exactly
        ``math.prod(outer_shape)`` are required.
    outer_shape
        Shape of the stacking dimensions (leading axes of the result).
    inner_shape
        Shape of each backing array (trailing axes of the result).
    dtype
        Dtype of the assembled result.

    Raises
    ------
    ValueError
        If the number of backing arrays does not match *outer_shape*.
    """

    def __init__(
        self,
        arrays: list[TensogramBackendArray],
        outer_shape: tuple[int, ...],
        inner_shape: tuple[int, ...],
        dtype: np.dtype,
    ):
        if len(arrays) != math.prod(outer_shape):
            msg = (
                f"StackedBackendArray: expected {math.prod(outer_shape)} "
                f"backing arrays for outer_shape={outer_shape}, "
                f"got {len(arrays)}"
            )
            raise ValueError(msg)

        self._arrays = arrays
        self._outer_shape = outer_shape
        self._inner_shape = inner_shape
        self.shape = outer_shape + inner_shape
        self.dtype = dtype

    def __getitem__(self, key: indexing.ExplicitIndexer) -> np.ndarray:
        return indexing.explicit_indexing_adapter(
            key,
            self.shape,
            indexing.IndexingSupport.BASIC,
            self._raw_indexing_method,
        )

    def _raw_indexing_method(self, key: tuple) -> np.ndarray:
        """Assemble ``data[key]`` for a basic (integer/slice) indexing *key*.

        As in numpy basic indexing, integer components drop their axis and
        slices keep it, on both the outer (stacking) and inner axes.

        Raises
        ------
        IndexError
            If an integer outer key is out of bounds.
        """
        n_outer = len(self._outer_shape)

        # Normalise integer outer keys (numpy integers, negative values) to
        # plain non-negative ints, remembering which axes they index so those
        # output dimensions can be dropped at the end.
        outer_key: list[Any] = []
        int_axes: list[int] = []
        for axis, (k, size) in enumerate(zip(key[:n_outer], self._outer_shape)):
            if not isinstance(k, (int, np.integer)):
                outer_key.append(k)
                continue
            idx = int(k)
            if idx < 0:
                idx += size
            if not 0 <= idx < size:
                msg = f"index {int(k)} is out of bounds for axis {axis} with size {size}"
                raise IndexError(msg)
            outer_key.append(idx)
            int_axes.append(axis)
        inner_key = key[n_outer:]

        # Concrete outer indices needed; integer keys expand to one index.
        outer_indices = _expand_key_to_indices(tuple(outer_key), self._outer_shape)
        outer_out_shape = tuple(len(idx) for idx in outer_indices)

        # Inner output shape: slices preserve the dimension (with the slice
        # length), integer keys drop it -- matching the backing arrays' output.
        inner_out_shape = tuple(
            len(range(*k.indices(s)))
            for k, s in zip(inner_key, self._inner_shape)
            if isinstance(k, slice)
        )

        result = np.empty(outer_out_shape + inner_out_shape, dtype=self.dtype)

        # Row-major strides of the outer grid, used to locate backing arrays.
        outer_strides = [1] * n_outer
        for d in range(n_outer - 2, -1, -1):
            outer_strides[d] = outer_strides[d + 1] * self._outer_shape[d + 1]

        # np.ndindex and itertools.product both enumerate in C order, so the
        # output position and the source index combination advance together.
        for out_pos, combo in zip(np.ndindex(*outer_out_shape), iterproduct(*outer_indices)):
            flat_idx = sum(i * st for i, st in zip(combo, outer_strides))
            result[out_pos] = self._arrays[flat_idx]._raw_indexing_method(inner_key)

        # Integer-indexed outer axes were expanded to length 1 above; drop them
        # so the result shape follows numpy basic-indexing semantics.
        if int_axes:
            result = np.squeeze(result, axis=tuple(int_axes))
        return result
|
|
396
|
+
|
|
397
|
+
|
|
398
|
+
def _expand_key_to_indices(key: tuple, shape: tuple[int, ...]) -> list[list[int]]:
    """Expand a tuple of slices/ints into lists of concrete indices.

    Slices are expanded via ``slice.indices``. Integers (Python ``int`` or any
    numpy integer) become a single-element list; negative values count from
    the end of the axis, as in numpy. Any other key component selects the
    whole axis.

    Raises
    ------
    IndexError
        If an integer key is out of bounds for its axis.
    """
    result: list[list[int]] = []
    for axis, (k, size) in enumerate(zip(key, shape)):
        if isinstance(k, slice):
            result.append(list(range(*k.indices(size))))
        elif isinstance(k, (int, np.integer)):
            idx = int(k)
            if idx < 0:
                idx += size
            if not 0 <= idx < size:
                msg = f"index {int(k)} is out of bounds for axis {axis} with size {size}"
                raise IndexError(msg)
            result.append([idx])
        else:
            result.append(list(range(size)))
    return result
|
|
@@ -0,0 +1,139 @@
|
|
|
1
|
+
# (C) Copyright 2026- ECMWF and individual contributors.
|
|
2
|
+
#
|
|
3
|
+
# This software is licensed under the terms of the Apache Licence Version 2.0
|
|
4
|
+
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
|
|
5
|
+
# In applying this licence, ECMWF does not waive the privileges and immunities
|
|
6
|
+
# granted to it by virtue of its status as an intergovernmental organisation nor
|
|
7
|
+
# does it submit to any jurisdiction.
|
|
8
|
+
|
|
9
|
+
"""xarray backend entry point for tensogram ``.tgm`` files.
|
|
10
|
+
|
|
11
|
+
Registers ``engine="tensogram"`` with xarray via the ``xarray.backends``
|
|
12
|
+
entry point in ``pyproject.toml``.
|
|
13
|
+
"""
|
|
14
|
+
|
|
15
|
+
from __future__ import annotations
|
|
16
|
+
|
|
17
|
+
import os
|
|
18
|
+
from collections.abc import Iterable, Sequence
|
|
19
|
+
from typing import Any
|
|
20
|
+
|
|
21
|
+
import xarray as xr
|
|
22
|
+
from xarray.backends import BackendEntrypoint
|
|
23
|
+
|
|
24
|
+
from tensogram_xarray.store import TensogramDataStore
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
class TensogramBackendEntrypoint(BackendEntrypoint):
    """Open tensogram ``.tgm`` files as xarray Datasets.

    Usage::

        import xarray as xr

        # Simple open (single message, generic dim names)
        ds = xr.open_dataset("file.tgm", engine="tensogram")

        # With user-specified dimension names
        ds = xr.open_dataset("file.tgm", engine="tensogram",
                             dim_names=["latitude", "longitude"])

        # With variable naming from metadata
        ds = xr.open_dataset("file.tgm", engine="tensogram",
                             variable_key="mars.param")
    """

    description = "Open tensogram .tgm files in xarray"
    url = "https://github.com/ecmwf/tensogram"

    def open_dataset(  # type: ignore[override]
        self,
        filename_or_obj: str | os.PathLike,
        *,
        drop_variables: str | Iterable[str] | None = None,
        dim_names: Sequence[str] | None = None,
        variable_key: str | None = None,
        message_index: int = 0,
        merge_objects: bool = False,
        verify_hash: bool = False,
        range_threshold: float = 0.5,
        storage_options: dict[str, Any] | None = None,
    ) -> xr.Dataset:
        """Open a single tensogram message as an :class:`xr.Dataset`.

        Parameters
        ----------
        filename_or_obj
            Path to a ``.tgm`` file.
        drop_variables
            Variable name, or iterable of names, to exclude from the Dataset.
            A bare string is treated as one name. In the *merge_objects* path,
            names that are not present are ignored.
        dim_names
            Explicit dimension names for data variables. Must have exactly
            as many entries as the tensor has axes.
        variable_key
            Dotted metadata path (e.g. ``"mars.param"``) whose value at each
            data object becomes the xarray variable name.
        message_index
            Which message to open when the file contains multiple messages.
        merge_objects
            If *True*, attempt to merge objects across messages by stacking
            along metadata dimensions that vary, and return only the first
            Dataset produced (an empty Dataset if there is none). When
            *False* (default), only the single message at *message_index* is
            opened.
        verify_hash
            Whether to verify xxh3 hashes during decode.
        range_threshold
            Maximum fraction of total array elements (0.0-1.0) for which
            partial ``decode_range()`` is used instead of a full
            ``decode_object()``. Default ``0.5`` (50%).
        storage_options
            Key-value pairs forwarded to the object store backend when
            the path is a remote URL. Ignored for local files.

        Returns
        -------
        xr.Dataset

        Raises
        ------
        ValueError
            If *message_index* is negative.
        """
        file_path = str(filename_or_obj)

        if message_index < 0:
            msg = f"message_index must be >= 0, got {message_index}"
            raise ValueError(msg)

        # A bare string is one variable name, not an iterable of one-character
        # names; set("t2m") would otherwise yield {"t", "2", "m"}.
        if isinstance(drop_variables, str):
            drop_variables = [drop_variables]
        drop_set = set(drop_variables) if drop_variables else None

        if merge_objects:
            # Delegate to open_datasets and return the first result.
            from tensogram_xarray.merge import open_datasets

            datasets = open_datasets(
                file_path,
                dim_names=dim_names,
                variable_key=variable_key,
                verify_hash=verify_hash,
                range_threshold=range_threshold,
                storage_options=storage_options,
            )
            if not datasets:
                return xr.Dataset()
            ds = datasets[0]
            if drop_set:
                # open_datasets() is not given drop_variables here, so prune
                # the returned Dataset instead of silently ignoring the request.
                pruned = ds.drop_vars(drop_set, errors="ignore")
                # drop_vars() returns a new Dataset; forward close() so closing
                # the result still closes the original Dataset.
                pruned.set_close(ds.close)
                ds = pruned
            return ds

        store = TensogramDataStore(
            file_path=file_path,
            msg_index=message_index,
            dim_names=dim_names,
            variable_key=variable_key,
            verify_hash=verify_hash,
            range_threshold=range_threshold,
            storage_options=storage_options,
        )

        ds = store.build_dataset(drop_variables=drop_set)
        ds.set_close(store.close)
        return ds

    def guess_can_open(self, filename_or_obj: str) -> bool:  # type: ignore[override]
        """Return *True* for files with ``.tgm`` extension (case-insensitive).

        Objects that ``os.path.splitext`` cannot handle yield *False*.
        """
        try:
            _, ext = os.path.splitext(filename_or_obj)
        except (TypeError, AttributeError):
            return False
        return ext.lower() == ".tgm"
|