tensogram-anemoi 0.17.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tensogram_anemoi-0.17.0/.gitignore +22 -0
- tensogram_anemoi-0.17.0/PKG-INFO +24 -0
- tensogram_anemoi-0.17.0/pyproject.toml +62 -0
- tensogram_anemoi-0.17.0/src/tensogram_anemoi/__init__.py +4 -0
- tensogram_anemoi-0.17.0/src/tensogram_anemoi/output.py +336 -0
- tensogram_anemoi-0.17.0/tests/test_tensogram_output.py +571 -0
|
@@ -0,0 +1,22 @@
|
|
|
1
|
+
.claude/*
|
|
2
|
+
!.claude/commands/
|
|
3
|
+
.weave/
|
|
4
|
+
.sisyphus/
|
|
5
|
+
.coverage
|
|
6
|
+
*.dylib.dSYM/
|
|
7
|
+
**/target
|
|
8
|
+
**/pkg
|
|
9
|
+
**/build/
|
|
10
|
+
docs/book/
|
|
11
|
+
python/**/dist/
|
|
12
|
+
python/bindings/Cargo.lock
|
|
13
|
+
rust/tensogram-grib/Cargo.lock
|
|
14
|
+
rust/tensogram-netcdf/Cargo.lock
|
|
15
|
+
rust/tensogram-wasm/Cargo.lock
|
|
16
|
+
# Python virtualenv, caches, and maturin-installed extension modules
|
|
17
|
+
.venv/
|
|
18
|
+
**/__pycache__/
|
|
19
|
+
*.pyc
|
|
20
|
+
python/bindings/python/tensogram/tensogram*.so
|
|
21
|
+
# TODO: decide whether uv.lock files should be ignored
|
|
22
|
+
**/uv.lock
|
|
@@ -0,0 +1,24 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: tensogram-anemoi
|
|
3
|
+
Version: 0.17.0
|
|
4
|
+
Summary: anemoi-inference output plugin for tensogram .tgm files
|
|
5
|
+
Project-URL: Homepage, https://sites.ecmwf.int/docs/tensogram/main
|
|
6
|
+
Project-URL: Repository, https://github.com/ecmwf/tensogram
|
|
7
|
+
Project-URL: Documentation, https://sites.ecmwf.int/docs/tensogram/main
|
|
8
|
+
Author-email: ECMWF <software@ecmwf.int>
|
|
9
|
+
License-Expression: Apache-2.0
|
|
10
|
+
Classifier: Development Status :: 4 - Beta
|
|
11
|
+
Classifier: Programming Language :: Python :: 3
|
|
12
|
+
Classifier: Topic :: Scientific/Engineering
|
|
13
|
+
Classifier: Topic :: Scientific/Engineering :: Atmospheric Science
|
|
14
|
+
Requires-Python: >=3.11
|
|
15
|
+
Requires-Dist: anemoi-inference<0.11,>=0.4
|
|
16
|
+
Requires-Dist: fsspec
|
|
17
|
+
Requires-Dist: numpy
|
|
18
|
+
Requires-Dist: tensogram<0.18,>=0.17.0
|
|
19
|
+
Provides-Extra: dev
|
|
20
|
+
Requires-Dist: pytest>=7.0; extra == 'dev'
|
|
21
|
+
Requires-Dist: ruff>=0.4; extra == 'dev'
|
|
22
|
+
Description-Content-Type: text/plain
|
|
23
|
+
|
|
24
|
+
anemoi-inference output plugin for tensogram .tgm files
|
|
@@ -0,0 +1,62 @@
|
|
|
1
|
+
[build-system]
|
|
2
|
+
requires = ["hatchling"]
|
|
3
|
+
build-backend = "hatchling.build"
|
|
4
|
+
|
|
5
|
+
[project]
|
|
6
|
+
name = "tensogram-anemoi"
|
|
7
|
+
version = "0.17.0"
|
|
8
|
+
description = "anemoi-inference output plugin for tensogram .tgm files"
|
|
9
|
+
readme = {text = "anemoi-inference output plugin for tensogram .tgm files", content-type = "text/plain"}
|
|
10
|
+
requires-python = ">=3.11"
|
|
11
|
+
license = "Apache-2.0"
|
|
12
|
+
authors = [{name = "ECMWF", email = "software@ecmwf.int"}]
|
|
13
|
+
classifiers = [
|
|
14
|
+
"Development Status :: 4 - Beta",
|
|
15
|
+
"Programming Language :: Python :: 3",
|
|
16
|
+
"Topic :: Scientific/Engineering",
|
|
17
|
+
"Topic :: Scientific/Engineering :: Atmospheric Science",
|
|
18
|
+
]
|
|
19
|
+
dependencies = [
|
|
20
|
+
"tensogram>=0.17.0,<0.18",
|
|
21
|
+
"anemoi-inference>=0.4,<0.11",
|
|
22
|
+
"fsspec",
|
|
23
|
+
"numpy",
|
|
24
|
+
]
|
|
25
|
+
|
|
26
|
+
[project.urls]
|
|
27
|
+
Homepage = "https://sites.ecmwf.int/docs/tensogram/main"
|
|
28
|
+
Repository = "https://github.com/ecmwf/tensogram"
|
|
29
|
+
Documentation = "https://sites.ecmwf.int/docs/tensogram/main"
|
|
30
|
+
|
|
31
|
+
[project.optional-dependencies]
|
|
32
|
+
dev = ["pytest>=7.0", "ruff>=0.4"]
|
|
33
|
+
|
|
34
|
+
[project.entry-points."anemoi.inference.outputs"]
|
|
35
|
+
tensogram = "tensogram_anemoi.output:TensogramOutput"
|
|
36
|
+
|
|
37
|
+
[tool.hatch.build.targets.wheel]
|
|
38
|
+
packages = ["src/tensogram_anemoi"]
|
|
39
|
+
|
|
40
|
+
[tool.ruff]
|
|
41
|
+
line-length = 99
|
|
42
|
+
target-version = "py311"
|
|
43
|
+
|
|
44
|
+
[tool.ruff.lint]
|
|
45
|
+
select = [
|
|
46
|
+
"E",
|
|
47
|
+
"W",
|
|
48
|
+
"F",
|
|
49
|
+
"I",
|
|
50
|
+
"N",
|
|
51
|
+
"UP",
|
|
52
|
+
"B",
|
|
53
|
+
"SIM",
|
|
54
|
+
"PT",
|
|
55
|
+
"RUF",
|
|
56
|
+
]
|
|
57
|
+
|
|
58
|
+
[tool.ruff.lint.isort]
|
|
59
|
+
known-third-party = ["numpy", "tensogram", "anemoi", "fsspec", "pytest"]
|
|
60
|
+
|
|
61
|
+
[tool.pytest.ini_options]
|
|
62
|
+
testpaths = ["tests"]
|
|
@@ -0,0 +1,336 @@
|
|
|
1
|
+
# (C) Copyright 2025 ECMWF.
|
|
2
|
+
#
|
|
3
|
+
# This software is licensed under the terms of the Apache Licence Version 2.0
|
|
4
|
+
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
|
|
5
|
+
|
|
6
|
+
import logging
|
|
7
|
+
from functools import cached_property
|
|
8
|
+
from pathlib import Path
|
|
9
|
+
|
|
10
|
+
import numpy as np
|
|
11
|
+
|
|
12
|
+
from anemoi.inference.context import Context
|
|
13
|
+
from anemoi.inference.decorators import main_argument
|
|
14
|
+
from anemoi.inference.output import Output
|
|
15
|
+
from anemoi.inference.types import ProcessorConfig
|
|
16
|
+
from anemoi.inference.types import State
|
|
17
|
+
|
|
18
|
+
LOG = logging.getLogger(__name__)

# Values written into base[i]["name"] for the latitude/longitude objects.
# They are deliberately NOT in tensogram-xarray's KNOWN_COORD_NAMES so that
# lat/lon objects share the same flat dimension as all field objects rather
# than each spawning its own dimension. The canonical names ("latitude" /
# "longitude") are preserved under the "anemoi" namespace for downstream
# consumers.
# NOTE(review): write_step also stores per-object dim_names of ["latitude"] /
# ["longitude"] on these objects; confirm how tensogram-xarray reconciles that
# with the shared-dimension intent described above.
_COORD_NAME_MAP = dict(
    latitude="grid_latitude",
    longitude="grid_longitude",
)
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
@main_argument("path")
class TensogramOutput(Output):
    """Tensogram output plugin for anemoi-inference.

    Writes each forecast step as one tensogram message to a .tgm file.
    Each message contains lat/lon coordinate objects followed by one object
    per field (or one stacked object per pressure-level parameter when
    ``stack_pressure_levels=True``).

    Per-object metadata is stored under the ``"anemoi"`` namespace in CBOR.
    Dimension-name hints are written into each base entry's ``dim_names``
    key (per-object, axis-ordered) *and* into ``_extra_["dim_names"]``
    (message-level size-to-name dict) so readers on either convention can
    resolve meaningful names without callers passing ``dim_names`` explicitly.

    Supports local paths and remote URLs (s3://, gs://, az://, ...) via fsspec.
    Each step is encoded and written immediately -- no full-forecast buffering.

    Pressure-level stacking
    -----------------------
    When ``stack_pressure_levels=True``, all fields sharing the same ``param``
    are stacked into a single 2-D object of shape ``(n_grid, n_levels)``,
    sorted by level in ascending order. The per-object metadata stores
    ``"levelist": [500, 850, ...]`` in the ``"mars"`` namespace instead of
    the scalar ``"level"`` key used for unstacked fields.

    Without stacking (default), every field is a separate 1-D object and the
    scalar ``"level"`` key is always stored when it is present in the
    checkpoint's GRIB keys.

    Parameters
    ----------
    context : Context
        anemoi-inference runtime context, forwarded to ``Output``.
    path : str
        Local path or fsspec URL of the output file. Missing parent
        directories of local paths are created by ``open()``.
    encoding : str
        Tensogram encoding written into every field descriptor.
        ``"simple_packing"`` requires ``bits``.
    bits : int | None
        Bits per value for ``simple_packing``; not used by other encodings.
    compression : str
        Tensogram compression written into every field descriptor.
    dtype : str
        NumPy dtype of field arrays (coordinates are always float64). Not used
        with ``simple_packing``, which always packs from float64.
    storage_options : dict | None
        Extra keyword arguments forwarded to ``fsspec.open``.
    stack_pressure_levels : bool
        Stack pressure-level fields per ``param`` (see above).
    variables, post_processors, output_frequency, write_initial_state
        Forwarded unchanged to ``anemoi.inference.output.Output``.

    Raises
    ------
    ValueError
        If ``encoding="simple_packing"`` is requested without ``bits``.
    """

    def __init__(
        self,
        context: Context,
        path: str,
        *,
        encoding: str = "none",
        bits: int | None = None,
        compression: str = "zstd",
        dtype: str = "float32",
        storage_options: dict | None = None,
        stack_pressure_levels: bool = False,
        variables: list[str] | None = None,
        post_processors: list[ProcessorConfig] | None = None,
        output_frequency: int | None = None,
        write_initial_state: bool | None = None,
    ) -> None:
        super().__init__(
            context,
            variables=variables,
            post_processors=post_processors,
            output_frequency=output_frequency,
            write_initial_state=write_initial_state,
        )
        # Fail at construction time instead of on the first write_step().
        if encoding == "simple_packing" and bits is None:
            raise ValueError("bits must be set when encoding='simple_packing'")
        self.path = path
        self.encoding = encoding
        self.bits = bits
        self.compression = compression
        self.dtype = dtype
        self.storage_options = storage_options or {}
        self.stack_pressure_levels = stack_pressure_levels
        # File-like object being written to, and the fsspec OpenFile that
        # produced it. Both are None while the output is not open.
        self._handle = None
        self._open_file = None

    def __repr__(self) -> str:
        return f"TensogramOutput({self.path})"

    @cached_property
    def _numpy_dtype(self) -> np.dtype:
        return np.dtype(self.dtype)

    # ------------------------------------------------------------------
    # Lifecycle
    # ------------------------------------------------------------------

    def open(self, state: State) -> None:
        """Open ``self.path`` for binary writing, truncating any existing file.

        ``state`` is not used; it is part of the ``Output`` interface.
        """
        import fsspec

        path_str = str(self.path)
        if "://" not in path_str:
            # Local path: create missing parent directories up front.
            Path(path_str).parent.mkdir(parents=True, exist_ok=True)

        # fsspec asks callers of OpenFile.open() to keep the OpenFile and close
        # it explicitly, so every layer it opened (e.g. a compression wrapper
        # requested via storage_options) is flushed and closed together.
        open_file = fsspec.open(path_str, "wb", **self.storage_options)
        self._handle = open_file.open()
        self._open_file = open_file
        LOG.info("TensogramOutput: writing to %s", path_str)

    def write_initial_state(self, state: State) -> None:
        """Write the initial state, reducing multi-step fields to the last step."""
        from anemoi.inference.state import reduce_state

        state = reduce_state(state)
        return super().write_initial_state(state)

    def write_step(self, state: State) -> None:
        """Encode one forecast step as a tensogram message and write it immediately.

        Raises
        ------
        RuntimeError
            If called before ``open()`` or after ``close()``.
        """
        if self._handle is None:
            raise RuntimeError(f"{self!r}: write_step called before open() or after close()")

        import tensogram

        global_meta = {
            "version": 3,
            "base": [],
            "_extra_": {},
        }
        # Parallel to global_meta["base"]: one (descriptor, array) per object.
        descriptors_and_data = []

        mars_extra = self._mars_time_keys(state)
        self._add_coordinates(state, global_meta, descriptors_and_data)
        if self.stack_pressure_levels:
            self._add_fields_stacked(state, global_meta, descriptors_and_data, mars_extra)
        else:
            self._add_fields_flat(state, global_meta, descriptors_and_data, mars_extra)

        global_meta["_extra_"]["dim_names"] = self._dim_names_hint(state, descriptors_and_data)

        msg_bytes = tensogram.encode(global_meta, descriptors_and_data)
        self._handle.write(msg_bytes)

    def close(self) -> None:
        """Flush and close the output stream; calling it again is a no-op."""
        if self._handle is None:
            return
        handle, open_file = self._handle, self._open_file
        # Reset first so the output reads as closed even if closing raises.
        self._handle = None
        self._open_file = None
        try:
            handle.flush()
        except Exception:
            LOG.warning(
                "TensogramOutput: failed to flush output stream before close", exc_info=True
            )
        if open_file is not None:
            # Closes every file object fsspec opened for this path.
            open_file.close()
        else:
            handle.close()

    # ------------------------------------------------------------------
    # Message-level helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _mars_time_keys(state: State) -> dict:
        """Return the MARS ``date``/``time``/``step`` keys shared by all field objects.

        ``date`` (YYYYMMDD) and ``time`` (HHMM) describe the base time, i.e. the
        state's ``date`` minus its ``step``. ``step`` is in hours: an ``int``
        for whole hours, otherwise a ``float``.
        """
        step_seconds = state["step"].total_seconds()
        step_hours = step_seconds / 3600
        if step_seconds % 3600 == 0:
            step_hours = int(step_hours)
        base_dt = state["date"] - state["step"]
        return {
            "date": base_dt.strftime("%Y%m%d"),
            "time": base_dt.strftime("%H%M"),
            "step": step_hours,
        }

    @staticmethod
    def _add_coordinates(state: State, global_meta: dict, descriptors_and_data: list) -> None:
        """Append the latitude and longitude objects (always stored as float64)."""
        for coord_name, coord_arr in [
            ("latitude", state["latitudes"]),
            ("longitude", state["longitudes"]),
        ]:
            arr = np.asarray(coord_arr, dtype=np.float64)
            global_meta["base"].append(
                {
                    "name": _COORD_NAME_MAP[coord_name],
                    "anemoi": {"variable": coord_name},
                    "dim_names": [coord_name],
                }
            )
            descriptors_and_data.append(
                (
                    {"type": "ntensor", "shape": list(arr.shape), "dtype": "float64"},
                    arr,
                )
            )

    def _dim_names_hint(self, state: State, descriptors_and_data: list) -> dict[str, str]:
        """Build the message-level ``{str(axis_size): dim_name}`` hint.

        The grid size maps to ``"values"``; with stacking, the second axis size
        of each 2-D object maps to ``"level"``. Because the hint is keyed by
        size, a level count equal to the grid size keeps the ``"values"`` name.
        """
        hint: dict[str, str] = {str(len(state["latitudes"])): "values"}
        if self.stack_pressure_levels:
            for _, arr in descriptors_and_data:
                if arr.ndim == 2:
                    hint.setdefault(str(arr.shape[1]), "level")
        return hint

    # ------------------------------------------------------------------
    # Field object builders
    # ------------------------------------------------------------------

    def _lookup_grib_keys(self, name: str) -> tuple[object | None, dict]:
        """Return ``(typed_variable, grib_keys)`` for field ``name``.

        Logs a warning and returns ``(None, {})`` when there is no typed
        variable for ``name``. A missing or ``None`` ``grib_keys`` attribute is
        treated as no keys.
        """
        variable = self.typed_variables.get(name)
        if variable is None:
            LOG.warning(
                "TensogramOutput: no typed variable for %r -- metadata will be incomplete",
                name,
            )
            return None, {}
        return variable, getattr(variable, "grib_keys", None) or {}

    def _add_fields_flat(
        self,
        state: State,
        global_meta: dict,
        descriptors_and_data: list,
        mars_extra: dict,
    ) -> None:
        """Add one object per field (default, no stacking), in input order."""
        for name, values in state["fields"].items():
            if self.skip_variable(name):
                continue
            _, grib = self._lookup_grib_keys(name)
            base_entry, descriptor, arr = self._build_field_object(name, grib, values, mars_extra)
            global_meta["base"].append(base_entry)
            descriptors_and_data.append((descriptor, arr))

    def _add_fields_stacked(
        self,
        state: State,
        global_meta: dict,
        descriptors_and_data: list,
        mars_extra: dict,
    ) -> None:
        """Group pressure-level fields by param and stack; write others flat.

        Non-pressure-level fields are appended first, in input order, followed
        by one stacked object per ``param`` in sorted ``param`` order.
        """
        pl_groups: dict[str, list[tuple[int, str, dict, np.ndarray]]] = {}
        non_pl: list[tuple[str, dict, np.ndarray]] = []

        for name, values in state["fields"].items():
            if self.skip_variable(name):
                continue
            variable, grib = self._lookup_grib_keys(name)
            if variable is not None and variable.is_pressure_level:
                pl_groups.setdefault(variable.param, []).append(
                    (variable.level, name, grib, values)
                )
            else:
                non_pl.append((name, grib, values))

        if not pl_groups:
            LOG.warning(
                "TensogramOutput: stack_pressure_levels=True but no pressure-level fields found"
            )

        for name, grib, values in non_pl:
            base_entry, descriptor, arr = self._build_field_object(name, grib, values, mars_extra)
            global_meta["base"].append(base_entry)
            descriptors_and_data.append((descriptor, arr))

        for param in sorted(pl_groups):
            group = sorted(pl_groups[param], key=lambda item: item[0])
            base_entry, descriptor, arr = self._build_stacked_object(param, group, mars_extra)
            global_meta["base"].append(base_entry)
            descriptors_and_data.append((descriptor, arr))

    # ------------------------------------------------------------------
    # Helpers
    # ------------------------------------------------------------------

    def _prepare_array(self, values: np.ndarray) -> np.ndarray:
        """Convert ``values`` to the array handed to the encoder.

        simple_packing packs from float64, so convert straight to float64
        instead of first casting to the configured ``dtype``. Going through
        e.g. float32 first would only discard precision before packing.
        """
        if self.encoding == "simple_packing":
            return np.asarray(values, dtype=np.float64)
        return np.asarray(values, dtype=self._numpy_dtype)

    def _build_descriptor(self, arr: np.ndarray) -> dict:
        """Build the tensogram descriptor for ``arr``.

        With simple_packing, the packing parameters computed by
        ``tensogram.compute_packing_params`` over all values of ``arr`` are
        merged into the descriptor.
        """
        descriptor = {
            "type": "ntensor",
            "shape": list(arr.shape),
            "dtype": arr.dtype.name,
            "encoding": self.encoding,
            "compression": self.compression,
        }
        if self.encoding == "simple_packing" and self.bits is not None:
            import tensogram

            sp_params = tensogram.compute_packing_params(arr.ravel(), self.bits, 0)
            descriptor.update(sp_params)
        return descriptor

    def _build_field_object(
        self,
        name: str,
        grib: dict,
        values: np.ndarray,
        mars_extra: dict,
    ) -> tuple[dict, dict, np.ndarray]:
        """Build (base_entry, descriptor, array) for a single flat field object.

        GRIB keys take precedence over ``mars_extra`` on key collisions.
        """
        mars = {**mars_extra, **grib}
        base_entry: dict = {
            "name": name,
            "anemoi": {"variable": name},
            "dim_names": ["values"],
        }
        if mars:
            base_entry["mars"] = mars
        arr = self._prepare_array(values)
        return base_entry, self._build_descriptor(arr), arr

    def _build_stacked_object(
        self,
        param: str,
        group: list[tuple[int, str, dict, np.ndarray]],
        mars_extra: dict,
    ) -> tuple[dict, dict, np.ndarray]:
        """Build (base_entry, descriptor, array) for a stacked pressure-level object.

        ``group`` items are ``(level, name, grib_keys, values)`` already sorted
        by level. The result has shape ``(n_grid, n_levels)``. Its ``mars``
        metadata takes the first member's GRIB keys minus ``"level"`` and adds
        ``"levelist"``. With simple_packing, one set of packing parameters
        covers all levels, because they form a single object.
        """
        levels = [item[0] for item in group]
        first_grib = group[0][2]

        arrays = [self._prepare_array(item[3]) for item in group]
        stacked = np.column_stack(arrays)

        mars = {
            **mars_extra,
            **{k: v for k, v in first_grib.items() if k != "level"},
            "levelist": levels,
        }

        base_entry: dict = {
            "name": param,
            "anemoi": {"variable": param},
            "dim_names": ["values", "level"],
        }
        if mars:
            base_entry["mars"] = mars
        return base_entry, self._build_descriptor(stacked), stacked
|
|
@@ -0,0 +1,571 @@
|
|
|
1
|
+
# (C) Copyright 2025 ECMWF.
|
|
2
|
+
#
|
|
3
|
+
# This software is licensed under the terms of the Apache Licence Version 2.0
|
|
4
|
+
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
|
|
5
|
+
|
|
6
|
+
"""Tests for TensogramOutput."""
|
|
7
|
+
|
|
8
|
+
from datetime import datetime
|
|
9
|
+
from datetime import timedelta
|
|
10
|
+
from types import SimpleNamespace
|
|
11
|
+
|
|
12
|
+
import numpy as np
|
|
13
|
+
import pytest
|
|
14
|
+
|
|
15
|
+
from tensogram_anemoi.output import TensogramOutput
|
|
16
|
+
|
|
17
|
+
tensogram = pytest.importorskip("tensogram", reason="tensogram not installed")
|
|
18
|
+
|
|
19
|
+
# ---------------------------------------------------------------------------
|
|
20
|
+
# Helpers
|
|
21
|
+
# ---------------------------------------------------------------------------
|
|
22
|
+
|
|
23
|
+
# Number of grid points in the synthetic test states.
N_GRID = 100
# Surface fields used by default in the synthetic context/state helpers.
FIELD_NAMES = "2t 10u 10v".split()
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
def _make_variable(param=None, levtype=None, level=None):
    """Return a SimpleNamespace standing in for an anemoi typed variable.

    Only the arguments that are not None end up in ``grib_keys`` (in the order
    param, levtype, level). The variable counts as pressure-level exactly when
    both ``levtype`` and ``level`` are given.
    """
    candidate_keys = {"param": param, "levtype": levtype, "level": level}
    grib_keys = {key: value for key, value in candidate_keys.items() if value is not None}
    return SimpleNamespace(
        grib_keys=grib_keys,
        is_computed_forcing=False,
        is_pressure_level=levtype is not None and level is not None,
        param=param,
        level=level,
    )
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
def _make_context(field_names=FIELD_NAMES):
    """Return a minimal stand-in for an anemoi-inference Context.

    The checkpoint carries one non-pressure-level typed variable per entry of
    ``field_names``, whose ``param`` equals the field name.
    """
    checkpoint = SimpleNamespace(
        typed_variables={field: _make_variable(param=field) for field in field_names},
    )
    return SimpleNamespace(
        checkpoint=checkpoint,
        reference_date=datetime(2024, 1, 1),
        write_initial_state=False,
        output_frequency=None,
        typed_variables={},
    )
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
def _make_state(step_hours=1, field_names=FIELD_NAMES, n_grid=N_GRID, seed=0):
    """Return a synthetic state valid ``step_hours`` after 2024-01-01 00:00.

    Field values are seeded uniform noise cast to float32. Latitudes and
    longitudes are evenly spaced float64 arrays of length ``n_grid``.
    """
    generator = np.random.default_rng(seed)
    lead_time = timedelta(hours=step_hours)
    fields = {}
    for field in field_names:
        fields[field] = generator.random(n_grid).astype(np.float32)
    return {
        "date": datetime(2024, 1, 1) + lead_time,
        "step": lead_time,
        "latitudes": np.linspace(-90, 90, n_grid, dtype=np.float64),
        "longitudes": np.linspace(0, 360, n_grid, dtype=np.float64),
        "fields": fields,
    }
|
|
69
|
+
|
|
70
|
+
|
|
71
|
+
# ---------------------------------------------------------------------------
|
|
72
|
+
# Tests
|
|
73
|
+
# ---------------------------------------------------------------------------
|
|
74
|
+
|
|
75
|
+
|
|
76
|
+
def test_write_step_before_open_raises(tmp_path):
    """write_step raises RuntimeError when called before open()."""
    unopened = TensogramOutput(_make_context(), str(tmp_path / "never.tgm"))
    state = _make_state()
    with pytest.raises(RuntimeError, match="open"):
        unopened.write_step(state)
|
|
82
|
+
|
|
83
|
+
|
|
84
|
+
def test_write_initial_state_reduces_multistep(tmp_path):
    """write_initial_state reduces a multi-step field array to its last step."""
    from unittest.mock import patch

    output = TensogramOutput(
        _make_context(["2t"]), str(tmp_path / "init.tgm"), write_initial_state=True
    )

    # Three steps of the same field; only the last one is expected to be written.
    history = np.random.default_rng(1).random((3, N_GRID)).astype(np.float32)
    state = {
        "date": datetime(2024, 1, 1),
        "step": timedelta(0),
        "latitudes": np.linspace(-90, 90, N_GRID, dtype=np.float64),
        "longitudes": np.linspace(0, 360, N_GRID, dtype=np.float64),
        "fields": {"2t": history},
    }

    captured_states = []
    output.open(state)
    with patch.object(output, "write_step", side_effect=captured_states.append):
        output.write_initial_state(state)
    output.close()

    assert captured_states, "write_step was not called by write_initial_state"
    np.testing.assert_array_equal(captured_states[0]["fields"]["2t"], history[-1])
|
|
115
|
+
|
|
116
|
+
|
|
117
|
+
def test_write_and_read_local(tmp_path):
    """Write 3 steps to a local .tgm file and verify round-trip."""
    target = tmp_path / "forecast.tgm"
    writer = TensogramOutput(_make_context(), str(target))

    states = [_make_state(hours) for hours in (1, 2, 3)]
    writer.open(states[0])
    for state in states:
        writer.write_step(state)
    writer.close()

    assert target.exists()

    tgm_file = tensogram.TensogramFile.open(str(target))
    assert len(tgm_file) == 3

    for state, msg in zip(states, tgm_file, strict=True):
        meta, objects = msg
        assert len(objects) == 5

        # Coordinates: objects 0 and 1.
        lat_entry, lon_entry = meta.base[0], meta.base[1]
        assert lat_entry["name"] == "grid_latitude"
        assert lon_entry["name"] == "grid_longitude"
        assert lat_entry["anemoi"]["variable"] == "latitude"
        assert lon_entry["anemoi"]["variable"] == "longitude"
        coord_pairs = ((objects[0], "latitudes"), (objects[1], "longitudes"))
        for (desc, _), _key in coord_pairs:
            assert desc.dtype == "float64"
        for (_, decoded), key in coord_pairs:
            np.testing.assert_allclose(decoded, state[key])

        # Fields: objects 2.. in FIELD_NAMES order.
        base_dt = state["date"] - state["step"]
        expected_mars = {
            "step": int(state["step"].total_seconds() / 3600),
            "date": base_dt.strftime("%Y%m%d"),
            "time": base_dt.strftime("%H%M"),
        }
        for obj_idx, name in enumerate(FIELD_NAMES, start=2):
            entry = meta.base[obj_idx]
            assert entry["name"] == name
            assert entry["anemoi"]["variable"] == name
            assert entry["mars"]["param"] == name
            for key, expected in expected_mars.items():
                assert entry["mars"][key] == expected
            _, decoded = objects[obj_idx]
            np.testing.assert_allclose(decoded, state["fields"][name], rtol=1e-6)
|
|
164
|
+
|
|
165
|
+
|
|
166
|
+
def test_variable_filter(tmp_path):
    """Only the selected variable and coordinates are written."""
    target = tmp_path / "filtered.tgm"
    writer = TensogramOutput(_make_context(), str(target), variables=["2t"])

    state = _make_state()
    writer.open(state)
    writer.write_step(state)
    writer.close()

    first_message = tensogram.TensogramFile.open(str(target))[0]
    meta, objects = first_message

    # latitude + longitude + the single selected field
    assert len(objects) == 3
    assert meta.base[2]["anemoi"]["variable"] == "2t"
|
|
183
|
+
|
|
184
|
+
|
|
185
|
+
def test_simple_packing_encoding(tmp_path):
    """simple_packing round-trip stays within expected quantisation error."""
    target = tmp_path / "packed.tgm"
    writer = TensogramOutput(
        _make_context(["2t"]),
        str(target),
        encoding="simple_packing",
        bits=16,
        compression="zstd",
    )

    # Values in [250, 300).
    values = (np.random.default_rng(42).random(N_GRID) * 50 + 250).astype(np.float32)
    state = {
        "date": datetime(2024, 1, 1, 1),
        "step": timedelta(hours=1),
        "latitudes": np.linspace(-90, 90, N_GRID, dtype=np.float64),
        "longitudes": np.linspace(0, 360, N_GRID, dtype=np.float64),
        "fields": {"2t": values},
    }

    writer.open(state)
    writer.write_step(state)
    writer.close()

    _, objects = tensogram.TensogramFile.open(str(target))[0]
    _, decoded = objects[2]

    np.testing.assert_allclose(decoded, values, atol=0.002)
|
|
212
|
+
|
|
213
|
+
|
|
214
|
+
def test_simple_packing_requires_bits(tmp_path):
    """encoding='simple_packing' without bits raises ValueError immediately."""
    target = str(tmp_path / "out.tgm")
    context = _make_context(["2t"])
    with pytest.raises(ValueError, match="bits must be set"):
        TensogramOutput(context, target, encoding="simple_packing")
|
|
219
|
+
|
|
220
|
+
|
|
221
|
+
def test_mars_metadata_forwarded(tmp_path):
    """grib_keys appear in per-object 'mars' namespace, following tensogram-grib convention."""
    target = tmp_path / "levs.tgm"
    context = SimpleNamespace(
        checkpoint=SimpleNamespace(
            typed_variables={"t500": _make_variable(param="t", levtype="pl", level=500)},
        ),
        reference_date=datetime(2024, 1, 1),
        write_initial_state=False,
        output_frequency=None,
        typed_variables={},
    )
    writer = TensogramOutput(context, str(target))

    n_points = 10
    state = {
        "date": datetime(2024, 1, 1, 6),
        "step": timedelta(hours=6),
        "latitudes": np.zeros(n_points, dtype=np.float64),
        "longitudes": np.zeros(n_points, dtype=np.float64),
        "fields": {"t500": np.ones(n_points, dtype=np.float32)},
    }
    writer.open(state)
    writer.write_step(state)
    writer.close()

    meta, _ = tensogram.TensogramFile.open(str(target))[0]
    field_entry = meta.base[2]
    expected_mars = {
        "param": "t",
        "levtype": "pl",
        "level": 500,
        "step": 6,
        "date": "20240101",
        "time": "0000",
    }
    for key, expected in expected_mars.items():
        assert field_entry["mars"][key] == expected, key
    assert field_entry["anemoi"]["variable"] == "t500"
|
|
258
|
+
|
|
259
|
+
|
|
260
|
+
def test_remote_write_via_memory_fs():
    """Write to fsspec memory filesystem and read back valid tensogram messages."""
    import fsspec

    url = "memory://test_forecast.tgm"
    writer = TensogramOutput(_make_context(), url)

    states = [_make_state(hours) for hours in (1, 2)]
    writer.open(states[0])
    for state in states:
        writer.write_step(state)
    writer.close()

    with fsspec.filesystem("memory").open(url, "rb") as stream:
        raw = stream.read()

    spans = tensogram.scan(raw)
    assert len(spans) == 2

    for offset, length in spans:
        _, objects = tensogram.decode(raw[offset : offset + length])
        assert len(objects) == 5
|
|
284
|
+
|
|
285
|
+
|
|
286
|
+
def test_dtype_float64(tmp_path):
    """dtype=float64 stores field arrays as float64."""
    target = tmp_path / "f64.tgm"
    writer = TensogramOutput(_make_context(["2t"]), str(target), dtype="float64")

    state = _make_state(field_names=["2t"])
    writer.open(state)
    writer.write_step(state)
    writer.close()

    _, objects = tensogram.TensogramFile.open(str(target))[0]
    descriptor, decoded = objects[2]
    assert descriptor.dtype == "float64"
    assert decoded.dtype == np.float64
|
|
302
|
+
|
|
303
|
+
|
|
304
|
+
def test_close_is_idempotent(tmp_path):
    """close() can be called multiple times without error."""
    writer = TensogramOutput(_make_context(), str(tmp_path / "idem.tgm"))
    writer.open(_make_state())
    writer.write_step(_make_state())
    for _ in range(2):
        writer.close()
|
|
313
|
+
|
|
314
|
+
|
|
315
|
+
def test_dim_names_in_metadata(tmp_path):
    """Dimension-name hints are written at both metadata levels.

    The message-level ``_extra_`` map and every per-object ``base[i]`` entry
    carry ``dim_names``. xarray readers that follow either convention can then
    resolve meaningful dimension names without explicit kwargs.
    """
    target = tmp_path / "dimnames.tgm"
    writer = TensogramOutput(_make_context(), str(target))
    step_state = _make_state()
    writer.open(step_state)
    writer.write_step(step_state)
    writer.close()

    tgm_file = tensogram.TensogramFile.open(str(target))
    meta, _ = tgm_file[0]

    # Message-level hints are keyed by the axis size rendered as a string.
    extra_hints = meta.extra["dim_names"]
    grid_key = str(N_GRID)
    assert grid_key in extra_hints
    assert extra_hints[grid_key] == "values"

    # Split base entries into coordinates (last one wins per name) and fields.
    coordinate_names = {"latitude", "longitude"}
    coords = {}
    fields = []
    for entry in meta.base:
        variable = entry.get("anemoi", {}).get("variable")
        if variable in coordinate_names:
            coords[variable] = entry
        else:
            fields.append(entry)

    # Each coordinate pins its single axis to its own canonical name.
    for coord in ("latitude", "longitude"):
        assert coords.get(coord, {}).get("dim_names") == [coord]

    # Every flat field entry uses the 'values' axis.
    assert fields, "expected at least one field entry"
    for entry in fields:
        assert entry.get("dim_names") == ["values"], entry
|
|
351
|
+
|
|
352
|
+
|
|
353
|
+
def test_stacked_dim_names_in_metadata(tmp_path):
    """Stacked fields carry both kinds of dimension-name hint.

    ``_extra_`` maps the grid size to 'values' and the level count to 'level'.
    Each stacked ``base`` entry gets the per-object list ``['values', 'level']``.
    """
    levels = [500, 850, 1000]
    target = tmp_path / "stacked_dims.tgm"
    writer = TensogramOutput(
        _make_pl_context(params=["t"], levels=levels),
        str(target),
        stack_pressure_levels=True,
    )
    step_state = _make_pl_state(params=["t"], levels=levels)
    writer.open(step_state)
    writer.write_step(step_state)
    writer.close()

    tgm_file = tensogram.TensogramFile.open(str(target))
    meta, _ = tgm_file[0]

    # Size-keyed hints: grid axis first, then the stacked level axis.
    hints = meta.extra["dim_names"]
    expected_hints = {str(N_GRID): "values", str(len(levels)): "level"}
    for size_key, dim_name in expected_hints.items():
        assert size_key in hints
        assert hints[size_key] == dim_name

    stacked_entries = [
        entry for entry in meta.base if entry.get("anemoi", {}).get("variable") == "t"
    ]
    assert stacked_entries, "expected stacked pressure-level entry"
    for entry in stacked_entries:
        assert entry.get("dim_names") == ["values", "level"], entry
|
|
378
|
+
|
|
379
|
+
|
|
380
|
+
# ---------------------------------------------------------------------------
|
|
381
|
+
# Pressure-level stacking tests
|
|
382
|
+
# ---------------------------------------------------------------------------
|
|
383
|
+
|
|
384
|
+
# Default pressure levels and parameters for the stacking tests.  These are
# tuples, not lists, because they serve as default arguments of
# _make_pl_context / _make_pl_state; an immutable default cannot be
# accidentally mutated and leak state between tests.
PL_LEVELS = (500, 850, 1000)
PL_PARAMS = ("t", "u")
|
|
386
|
+
|
|
387
|
+
|
|
388
|
+
def _make_pl_context(params=PL_PARAMS, levels=PL_LEVELS, extra_fields=None):
    """Build a stub context holding pressure-level variables plus optional extras.

    One ``levtype="pl"`` variable named ``f"{param}{level}"`` is registered for
    every (param, level) pair, in param-major order.  It is followed by one
    default variable per name in *extra_fields*.  The variables are attached to
    ``context.checkpoint.typed_variables``; the context's own
    ``typed_variables`` attribute is left as an empty dict.
    """
    pl_variables = {
        f"{param}{level}": _make_variable(param=param, levtype="pl", level=level)
        for param in params
        for level in levels
    }
    extra_variables = {name: _make_variable(param=name) for name in (extra_fields or [])}

    return SimpleNamespace(
        checkpoint=SimpleNamespace(typed_variables={**pl_variables, **extra_variables}),
        reference_date=datetime(2024, 1, 1),
        write_initial_state=False,
        output_frequency=None,
        typed_variables={},
    )
|
|
406
|
+
|
|
407
|
+
|
|
408
|
+
def _make_pl_state(params=PL_PARAMS, levels=PL_LEVELS, extra_fields=None, n_grid=N_GRID, seed=0):
    """Build a deterministic single-step state with pressure-level fields.

    Field names follow the ``f"{param}{level}"`` scheme (param-major), followed
    by the names in *extra_fields*.  Each field is a float32 vector of *n_grid*
    uniform draws from a generator seeded with *seed*, drawn in that name order.
    The state is dated 2024-01-01 01:00 with a one-hour step.
    """
    generator = np.random.default_rng(seed)
    field_names = [f"{param}{level}" for param in params for level in levels]
    field_names.extend(extra_fields or [])
    # Draw in name order so the output matches a given seed deterministically.
    field_data = {name: generator.random(n_grid).astype(np.float32) for name in field_names}

    return {
        "date": datetime(2024, 1, 1, 1),
        "step": timedelta(hours=1),
        "latitudes": np.linspace(-90, 90, n_grid, dtype=np.float64),
        "longitudes": np.linspace(0, 360, n_grid, dtype=np.float64),
        "fields": field_data,
    }
|
|
423
|
+
|
|
424
|
+
|
|
425
|
+
def test_stack_object_count(tmp_path):
    """With stacking enabled, one object is written per param instead of per level."""
    target = tmp_path / "stacked.tgm"
    writer = TensogramOutput(_make_pl_context(), str(target), stack_pressure_levels=True)

    step_state = _make_pl_state()
    writer.open(step_state)
    writer.write_step(step_state)
    writer.close()

    tgm_file = tensogram.TensogramFile.open(str(target))
    _meta, objects = tgm_file[0]

    # Two leading objects, then exactly one stacked object per parameter.
    expected_count = 2 + len(PL_PARAMS)
    assert len(objects) == expected_count
|
|
440
|
+
|
|
441
|
+
|
|
442
|
+
def test_stack_shape(tmp_path):
    """Each stacked object is laid out as (n_grid, n_levels), grid axis first."""
    target = tmp_path / "stacked.tgm"
    writer = TensogramOutput(_make_pl_context(), str(target), stack_pressure_levels=True)

    step_state = _make_pl_state()
    writer.open(step_state)
    writer.write_step(step_state)
    writer.close()

    tgm_file = tensogram.TensogramFile.open(str(target))
    _meta, objects = tgm_file[0]

    first_field = 2
    last_field = first_field + len(PL_PARAMS)
    for position in range(first_field, last_field):
        _desc, values = objects[position]
        assert values.shape == (N_GRID, len(PL_LEVELS)), (
            f"expected ({N_GRID}, {len(PL_LEVELS)}), got {values.shape}"
        )
|
|
461
|
+
|
|
462
|
+
|
|
463
|
+
def test_stack_levels_metadata(tmp_path):
    """Stacked entries expose an ascending MARS 'levelist' and no scalar 'level'."""
    target = tmp_path / "stacked.tgm"
    writer = TensogramOutput(_make_pl_context(), str(target), stack_pressure_levels=True)

    step_state = _make_pl_state()
    writer.open(step_state)
    writer.write_step(step_state)
    writer.close()

    tgm_file = tensogram.TensogramFile.open(str(target))
    meta, _ = tgm_file[0]

    expected_levelist = sorted(PL_LEVELS)
    for position in range(2, 2 + len(PL_PARAMS)):
        entry = meta.base[position]
        anemoi_meta, mars_meta = entry["anemoi"], entry["mars"]

        # Level information must appear only in the MARS block, as a sorted list.
        for key in ("levelist", "level"):
            assert key not in anemoi_meta
        assert "level" not in mars_meta
        assert mars_meta["levelist"] == expected_levelist
        assert mars_meta["levtype"] == "pl"

        # The entry's top-level name mirrors the MARS param.
        assert "name" in entry
        assert entry["name"] == mars_meta["param"]
|
|
488
|
+
|
|
489
|
+
|
|
490
|
+
def test_stack_values_round_trip(tmp_path):
    """Stacked values decode back to the inputs, with columns in ascending level order."""
    unsorted_levels = [850, 500, 1000]
    target = tmp_path / "stacked.tgm"
    writer = TensogramOutput(
        _make_pl_context(params=["t"], levels=unsorted_levels),
        str(target),
        stack_pressure_levels=True,
    )

    rng = np.random.default_rng(7)
    fields = {f"t{level}": rng.random(N_GRID).astype(np.float32) for level in unsorted_levels}
    state = {
        "date": datetime(2024, 1, 1, 1),
        "step": timedelta(hours=1),
        "latitudes": np.linspace(-90, 90, N_GRID, dtype=np.float64),
        "longitudes": np.linspace(0, 360, N_GRID, dtype=np.float64),
        "fields": fields,
    }

    writer.open(state)
    writer.write_step(state)
    writer.close()

    tgm_file = tensogram.TensogramFile.open(str(target))
    meta, objects = tgm_file[0]

    # Levels are supplied out of order; the writer must sort them ascending.
    assert meta.base[2]["mars"]["levelist"] == [500, 850, 1000]

    _, stacked = objects[2]
    for column, level in enumerate(sorted(unsorted_levels)):
        np.testing.assert_allclose(stacked[:, column], fields[f"t{level}"], rtol=1e-6)
|
|
520
|
+
|
|
521
|
+
|
|
522
|
+
def test_stack_non_pl_fields_written_flat(tmp_path):
    """Non-PL fields are written as individual flat objects even with stacking enabled.

    The stacking metadata key is ``levelist`` and lives in the ``mars`` block.
    A flat field must carry it in neither ``mars`` nor ``anemoi``, and its
    array must stay one-dimensional.
    """
    path = tmp_path / "mixed.tgm"
    context = _make_pl_context(params=["t"], levels=[500, 850], extra_fields=["2t"])
    output = TensogramOutput(context, str(path), stack_pressure_levels=True)

    state = _make_pl_state(params=["t"], levels=[500, 850], extra_fields=["2t"])
    output.open(state)
    output.write_step(state)
    output.close()

    tgm_file = tensogram.TensogramFile.open(str(path))
    meta, objects = tgm_file[0]

    # 2 leading objects + one stacked "t" object + one flat "2t" object.
    assert len(objects) == 4

    flat_indices = [
        i
        for i in range(2, len(objects))
        if meta.base[i]["anemoi"].get("variable") == "2t"
    ]
    assert len(flat_indices) == 1
    flat_idx = flat_indices[0]

    anemoi = meta.base[flat_idx]["anemoi"]
    mars = meta.base[flat_idx].get("mars", {})
    assert "level" not in anemoi
    # 'levelist' is the key used by stacking; it must not leak onto flat fields.
    assert "levelist" not in anemoi
    assert "levelist" not in mars

    _, flat_arr = objects[flat_idx]
    assert flat_arr.shape == (N_GRID,)
|
|
546
|
+
|
|
547
|
+
|
|
548
|
+
def test_no_stack_level_metadata_correct(tmp_path):
    """Without stacking, each PL field stores a scalar 'level' and 'levtype'.

    Each per-level object must carry its own scalar level, and the written
    levels must be exactly the requested ones.  The stacking key ``levelist``
    must not appear.
    """
    path = tmp_path / "flat_pl.tgm"
    context = _make_pl_context(params=["t"], levels=[500, 850])
    output = TensogramOutput(context, str(path), stack_pressure_levels=False)

    state = _make_pl_state(params=["t"], levels=[500, 850])
    output.open(state)
    output.write_step(state)
    output.close()

    tgm_file = tensogram.TensogramFile.open(str(path))
    meta, objects = tgm_file[0]

    # 2 leading objects + one object per (param, level) pair.
    assert len(objects) == 4

    written_levels = []
    for obj_idx in range(2, 4):
        mars = meta.base[obj_idx]["mars"]
        anemoi = meta.base[obj_idx]["anemoi"]
        assert "level" in mars
        assert "levtype" in mars
        assert mars["levtype"] == "pl"
        assert mars["param"] == "t"
        # Unstacked objects must not carry the stacked 'levelist' key anywhere.
        assert "levelist" not in mars
        assert "levelist" not in anemoi
        written_levels.append(mars["level"])

    # Object order is not pinned, but the set of levels must match the request.
    assert sorted(written_levels) == [500, 850]
|