tensogram-anemoi 0.17.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,22 @@
1
+ .claude/*
2
+ !.claude/commands/
3
+ .weave/
4
+ .sisyphus/
5
+ .coverage
6
+ *.dylib.dSYM/
7
+ **/target
8
+ **/pkg
9
+ **/build/
10
+ docs/book/
11
+ python/**/dist/
12
+ python/bindings/Cargo.lock
13
+ rust/tensogram-grib/Cargo.lock
14
+ rust/tensogram-netcdf/Cargo.lock
15
+ rust/tensogram-wasm/Cargo.lock
16
+ # Python virtualenv, caches, and maturin-installed extension modules
17
+ .venv/
18
+ **/__pycache__/
19
+ *.pyc
20
+ python/bindings/python/tensogram/tensogram*.so
21
+ # TODO: decide whether uv.lock files should be ignored.
22
+ **/uv.lock
@@ -0,0 +1,24 @@
1
+ Metadata-Version: 2.4
2
+ Name: tensogram-anemoi
3
+ Version: 0.17.0
4
+ Summary: anemoi-inference output plugin for tensogram .tgm files
5
+ Project-URL: Homepage, https://sites.ecmwf.int/docs/tensogram/main
6
+ Project-URL: Repository, https://github.com/ecmwf/tensogram
7
+ Project-URL: Documentation, https://sites.ecmwf.int/docs/tensogram/main
8
+ Author-email: ECMWF <software@ecmwf.int>
9
+ License-Expression: Apache-2.0
10
+ Classifier: Development Status :: 4 - Beta
11
+ Classifier: Programming Language :: Python :: 3
12
+ Classifier: Topic :: Scientific/Engineering
13
+ Classifier: Topic :: Scientific/Engineering :: Atmospheric Science
14
+ Requires-Python: >=3.11
15
+ Requires-Dist: anemoi-inference<0.11,>=0.4
16
+ Requires-Dist: fsspec
17
+ Requires-Dist: numpy
18
+ Requires-Dist: tensogram<0.18,>=0.17.0
19
+ Provides-Extra: dev
20
+ Requires-Dist: pytest>=7.0; extra == 'dev'
21
+ Requires-Dist: ruff>=0.4; extra == 'dev'
22
+ Description-Content-Type: text/plain
23
+
24
+ anemoi-inference output plugin for tensogram .tgm files
@@ -0,0 +1,62 @@
1
+ [build-system]
2
+ requires = ["hatchling"]
3
+ build-backend = "hatchling.build"
4
+
5
+ [project]
6
+ name = "tensogram-anemoi"
7
+ version = "0.17.0"
8
+ description = "anemoi-inference output plugin for tensogram .tgm files"
9
+ readme = {text = "anemoi-inference output plugin for tensogram .tgm files", content-type = "text/plain"}
10
+ requires-python = ">=3.11"
11
+ license = "Apache-2.0"
12
+ authors = [{name = "ECMWF", email = "software@ecmwf.int"}]
13
+ classifiers = [
14
+ "Development Status :: 4 - Beta",
15
+ "Programming Language :: Python :: 3",
16
+ "Topic :: Scientific/Engineering",
17
+ "Topic :: Scientific/Engineering :: Atmospheric Science",
18
+ ]
19
+ dependencies = [
20
+ "tensogram>=0.17.0,<0.18",
21
+ "anemoi-inference>=0.4,<0.11",
22
+ "fsspec",
23
+ "numpy",
24
+ ]
25
+
26
+ [project.urls]
27
+ Homepage = "https://sites.ecmwf.int/docs/tensogram/main"
28
+ Repository = "https://github.com/ecmwf/tensogram"
29
+ Documentation = "https://sites.ecmwf.int/docs/tensogram/main"
30
+
31
+ [project.optional-dependencies]
32
+ dev = ["pytest>=7.0", "ruff>=0.4"]
33
+
34
+ [project.entry-points."anemoi.inference.outputs"]
35
+ tensogram = "tensogram_anemoi.output:TensogramOutput"
36
+
37
+ [tool.hatch.build.targets.wheel]
38
+ packages = ["src/tensogram_anemoi"]
39
+
40
+ [tool.ruff]
41
+ line-length = 99
42
+ target-version = "py311"
43
+
44
+ [tool.ruff.lint]
45
+ select = [
46
+ "E",
47
+ "W",
48
+ "F",
49
+ "I",
50
+ "N",
51
+ "UP",
52
+ "B",
53
+ "SIM",
54
+ "PT",
55
+ "RUF",
56
+ ]
57
+
58
+ [tool.ruff.lint.isort]
59
+ known-third-party = ["numpy", "tensogram", "anemoi", "fsspec", "pytest"]
60
+
61
+ [tool.pytest.ini_options]
62
+ testpaths = ["tests"]
@@ -0,0 +1,4 @@
1
+ # (C) Copyright 2025 ECMWF.
2
+ #
3
+ # This software is licensed under the terms of the Apache Licence Version 2.0
4
+ # which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
@@ -0,0 +1,336 @@
1
+ # (C) Copyright 2025 ECMWF.
2
+ #
3
+ # This software is licensed under the terms of the Apache Licence Version 2.0
4
+ # which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
5
+
6
+ import logging
7
+ from functools import cached_property
8
+ from pathlib import Path
9
+
10
+ import numpy as np
11
+
12
+ from anemoi.inference.context import Context
13
+ from anemoi.inference.decorators import main_argument
14
+ from anemoi.inference.output import Output
15
+ from anemoi.inference.types import ProcessorConfig
16
+ from anemoi.inference.types import State
17
+
18
# Module-level logger, named after this module.
LOG = logging.getLogger(__name__)

# Maps the canonical coordinate name (key) to the object name written into
# base[i]["name"] for lat/lon objects (value). These names are deliberately
# NOT in tensogram-xarray's KNOWN_COORD_NAMES so that lat/lon objects share the
# same flat dimension as all field objects rather than each spawning its own
# dimension. The canonical names are preserved in the "anemoi" namespace for
# downstream consumers.
_COORD_NAME_MAP = {
    "latitude": "grid_latitude",
    "longitude": "grid_longitude",
}
29
+
30
+
31
@main_argument("path")
class TensogramOutput(Output):
    """Tensogram output plugin for anemoi-inference.

    Writes each forecast step as one tensogram message to a .tgm file.
    Each message contains lat/lon coordinate objects followed by one object
    per field (or one stacked object per pressure-level parameter when
    ``stack_pressure_levels=True``).

    Per-object metadata is stored under the ``"anemoi"`` namespace in CBOR.
    Dimension-name hints are written into each base entry's ``dim_names``
    key (per-object, axis-ordered) *and* into ``_extra_["dim_names"]``
    (message-level size-to-name dict) so readers on either convention can
    resolve meaningful names without callers passing ``dim_names`` explicitly.

    Supports local paths and remote URLs (s3://, gs://, az://, ...) via fsspec.
    Each step is encoded and written immediately -- no full-forecast buffering.

    Pressure-level stacking
    -----------------------
    When ``stack_pressure_levels=True``, all fields sharing the same ``param``
    are stacked into a single 2-D object of shape ``(n_grid, n_levels)``,
    sorted by level in ascending order. The per-object metadata stores
    ``"levelist": [500, 850, ...]`` in the ``"mars"`` namespace instead of
    the scalar ``"level"`` key used for unstacked fields.

    Without stacking (default), every field is a separate 1-D object and the
    scalar ``"level"`` key is always stored when it is present in the
    checkpoint's GRIB keys.
    """

    def __init__(
        self,
        context: Context,
        path: str,
        *,
        encoding: str = "none",
        bits: int | None = None,
        compression: str = "zstd",
        dtype: str = "float32",
        storage_options: dict | None = None,
        stack_pressure_levels: bool = False,
        variables: list[str] | None = None,
        post_processors: list[ProcessorConfig] | None = None,
        output_frequency: int | None = None,
        write_initial_state: bool | None = None,
    ) -> None:
        """Configure the output; no file is opened until :meth:`open` is called.

        Parameters
        ----------
        context:
            Inference context, forwarded to the ``Output`` base class.
        path:
            Local path or fsspec URL of the .tgm file to write.
        encoding:
            Written verbatim into every field descriptor. ``"simple_packing"``
            additionally requires ``bits``.
        bits:
            Bit width passed to ``tensogram.compute_packing_params`` when
            ``encoding="simple_packing"``.
        compression:
            Written verbatim into every field descriptor.
        dtype:
            NumPy dtype name that field arrays are converted to.
        storage_options:
            Extra keyword arguments passed to ``fsspec.open``.
        stack_pressure_levels:
            Stack pressure-level fields per parameter (see class docstring).
        variables, post_processors, output_frequency, write_initial_state:
            Forwarded unchanged to the ``Output`` base class.

        Raises
        ------
        ValueError
            If ``encoding="simple_packing"`` is requested without ``bits``.
        """
        super().__init__(
            context,
            variables=variables,
            post_processors=post_processors,
            output_frequency=output_frequency,
            write_initial_state=write_initial_state,
        )
        if encoding == "simple_packing" and bits is None:
            raise ValueError("bits must be set when encoding='simple_packing'")
        self.path = path
        self.encoding = encoding
        self.bits = bits
        self.compression = compression
        self.dtype = dtype
        self.storage_options = storage_options or {}
        self.stack_pressure_levels = stack_pressure_levels
        # Binary write handle; None whenever the output is not open.
        self._handle = None

    def __repr__(self) -> str:
        return f"TensogramOutput({self.path})"

    @cached_property
    def _numpy_dtype(self) -> np.dtype:
        """NumPy dtype object for the configured ``dtype`` name (computed once)."""
        return np.dtype(self.dtype)

    # ------------------------------------------------------------------
    # Lifecycle
    # ------------------------------------------------------------------

    def open(self, state: State) -> None:
        """Open the destination for binary writing, creating local parent directories.

        ``state`` is accepted for interface compatibility and is not used here.
        """
        import fsspec

        path_str = str(self.path)
        # Paths without a URL scheme are treated as local files, whose parent
        # directory may not exist yet.
        if "://" not in path_str:
            Path(path_str).parent.mkdir(parents=True, exist_ok=True)

        self._handle = fsspec.open(path_str, "wb", **self.storage_options).open()
        LOG.info("TensogramOutput: writing to %s", path_str)

    def write_initial_state(self, state: State) -> None:
        """Write the initial state, reducing multi-step fields to the last step."""
        from anemoi.inference.state import reduce_state

        state = reduce_state(state)
        return super().write_initial_state(state)

    def write_step(self, state: State) -> None:
        """Encode one forecast step as a tensogram message and write it immediately.

        Raises
        ------
        RuntimeError
            If called before :meth:`open` or after :meth:`close`.
        """
        if self._handle is None:
            raise RuntimeError(f"{self!r}: write_step called before open() or after close()")

        global_meta = {
            "version": 3,
            "base": [],
            "_extra_": {},
        }
        descriptors_and_data = []

        # Whole-hour steps are stored as int, fractional steps as float hours.
        step_seconds = state["step"].total_seconds()
        step_hours = int(step_seconds / 3600) if step_seconds % 3600 == 0 else step_seconds / 3600
        # state["date"] is the valid time; subtracting the step gives the base time.
        base_dt = state["date"] - state["step"]
        mars_extra = {
            "date": base_dt.strftime("%Y%m%d"),
            "time": base_dt.strftime("%H%M"),
            "step": step_hours,
        }

        # Coordinates always come first and are always stored as float64.
        for coord_name, coord_arr in [
            ("latitude", state["latitudes"]),
            ("longitude", state["longitudes"]),
        ]:
            arr = np.asarray(coord_arr, dtype=np.float64)
            global_meta["base"].append(
                {
                    "name": _COORD_NAME_MAP[coord_name],
                    "anemoi": {"variable": coord_name},
                    "dim_names": [coord_name],
                }
            )
            descriptors_and_data.append(
                (
                    {"type": "ntensor", "shape": list(arr.shape), "dtype": "float64"},
                    arr,
                )
            )

        if self.stack_pressure_levels:
            self._add_fields_stacked(state, global_meta, descriptors_and_data, mars_extra)
        else:
            self._add_fields_flat(state, global_meta, descriptors_and_data, mars_extra)

        # Message-level size -> dimension-name hints.
        n_grid = len(state["latitudes"])
        dim_names_hint: dict[str, str] = {str(n_grid): "values"}
        if self.stack_pressure_levels:
            for _, arr in descriptors_and_data:
                if arr.ndim == 2:
                    level_size = str(arr.shape[1])
                    # Never overwrite an existing hint (e.g. n_levels == n_grid).
                    if level_size not in dim_names_hint:
                        dim_names_hint[level_size] = "level"
        global_meta["_extra_"]["dim_names"] = dim_names_hint

        import tensogram

        msg_bytes = tensogram.encode(global_meta, descriptors_and_data)
        self._handle.write(msg_bytes)

    def close(self) -> None:
        """Flush and close the output stream; safe to call more than once.

        A failing flush is logged and does not prevent the close. An exception
        raised by the close itself still propagates to the caller.
        """
        handle = self._handle
        if handle is None:
            return
        # Detach the handle before touching it, so that a failing close()
        # cannot leave a broken handle behind for later close()/write_step() calls.
        self._handle = None
        try:
            handle.flush()
        except Exception:
            LOG.warning(
                "TensogramOutput: failed to flush output stream before close", exc_info=True
            )
        handle.close()

    # ------------------------------------------------------------------
    # Field object builders
    # ------------------------------------------------------------------

    def _add_fields_flat(
        self,
        state: State,
        global_meta: dict,
        descriptors_and_data: list,
        mars_extra: dict,
    ) -> None:
        """Add one object per field (default, no stacking)."""
        for name, values in state["fields"].items():
            if self.skip_variable(name):
                continue
            _, grib = self._lookup_variable(name)
            base_entry, descriptor, arr = self._build_field_object(name, grib, values, mars_extra)
            global_meta["base"].append(base_entry)
            descriptors_and_data.append((descriptor, arr))

    def _add_fields_stacked(
        self,
        state: State,
        global_meta: dict,
        descriptors_and_data: list,
        mars_extra: dict,
    ) -> None:
        """Group pressure-level fields by param and stack; write others flat."""
        pl_groups: dict[str, list[tuple[int, str, dict, np.ndarray]]] = {}
        non_pl: list[tuple[str, dict, np.ndarray]] = []

        for name, values in state["fields"].items():
            if self.skip_variable(name):
                continue
            variable, grib = self._lookup_variable(name)
            if variable is not None and variable.is_pressure_level:
                pl_groups.setdefault(variable.param, []).append(
                    (variable.level, name, grib, values)
                )
            else:
                non_pl.append((name, grib, values))

        if not pl_groups:
            LOG.warning(
                "TensogramOutput: stack_pressure_levels=True but no pressure-level fields found"
            )

        # Non-pressure-level fields are written first, in state order.
        for name, grib, values in non_pl:
            base_entry, descriptor, arr = self._build_field_object(name, grib, values, mars_extra)
            global_meta["base"].append(base_entry)
            descriptors_and_data.append((descriptor, arr))

        # Stacked objects follow, ordered by param name, levels ascending.
        for param in sorted(pl_groups):
            group = sorted(pl_groups[param], key=lambda x: x[0])
            base_entry, descriptor, arr = self._build_stacked_object(param, group, mars_extra)
            global_meta["base"].append(base_entry)
            descriptors_and_data.append((descriptor, arr))

    # ------------------------------------------------------------------
    # Helpers
    # ------------------------------------------------------------------

    def _lookup_variable(self, name: str) -> tuple[object | None, dict]:
        """Return ``(typed_variable, grib_keys)`` for field *name*.

        Logs a warning and returns ``(None, {})`` when no typed variable is
        known for the field, so the field is still written with reduced metadata.
        """
        variable = self.typed_variables.get(name)
        if variable is None:
            LOG.warning(
                "TensogramOutput: no typed variable for %r -- metadata will be incomplete",
                name,
            )
        grib = getattr(variable, "grib_keys", {}) if variable else {}
        return variable, grib

    def _prepare_array(self, values: np.ndarray) -> np.ndarray:
        """Convert *values* to the configured dtype (float64 for simple_packing)."""
        arr = np.asarray(values, dtype=self._numpy_dtype)
        if self.encoding == "simple_packing":
            # NOTE(review): presumably the packer requires float64 input -- confirm.
            arr = arr.astype(np.float64)
        return arr

    def _build_descriptor(self, arr: np.ndarray) -> dict:
        """Build the tensogram descriptor for *arr*, including packing params if needed."""
        descriptor = {
            "type": "ntensor",
            "shape": list(arr.shape),
            "dtype": arr.dtype.name,
            "encoding": self.encoding,
            "compression": self.compression,
        }
        if self.encoding == "simple_packing" and self.bits is not None:
            import tensogram

            sp_params = tensogram.compute_packing_params(arr.ravel(), self.bits, 0)
            descriptor.update(sp_params)
        return descriptor

    def _build_field_object(
        self,
        name: str,
        grib: dict,
        values: np.ndarray,
        mars_extra: dict,
    ) -> tuple[dict, dict, np.ndarray]:
        """Build (base_entry, descriptor, array) for a single flat field object."""
        # GRIB keys take precedence over the per-step date/time/step values.
        mars = {**mars_extra, **grib}
        base_entry: dict = {
            "name": name,
            "anemoi": {"variable": name},
            "dim_names": ["values"],
        }
        if mars:
            base_entry["mars"] = mars
        arr = self._prepare_array(values)
        return base_entry, self._build_descriptor(arr), arr

    def _build_stacked_object(
        self,
        param: str,
        group: list[tuple[int, str, dict, np.ndarray]],
        mars_extra: dict,
    ) -> tuple[dict, dict, np.ndarray]:
        """Build (base_entry, descriptor, array) for a stacked pressure-level object.

        *group* holds ``(level, name, grib_keys, values)`` tuples, already
        sorted by level by the caller; it must not be empty.
        """
        levels = [item[0] for item in group]
        # GRIB keys other than "level" are taken from the first level only.
        first_grib = group[0][2]

        arrays = [self._prepare_array(item[3]) for item in group]
        # One column per level: 1-D inputs give shape (n_grid, n_levels).
        stacked = np.column_stack(arrays)

        mars = {
            **mars_extra,
            **{k: v for k, v in first_grib.items() if k != "level"},
            "levelist": levels,
        }

        base_entry: dict = {
            "name": param,
            "anemoi": {"variable": param},
            "dim_names": ["values", "level"],
        }
        if mars:
            base_entry["mars"] = mars
        return base_entry, self._build_descriptor(stacked), stacked
@@ -0,0 +1,571 @@
1
+ # (C) Copyright 2025 ECMWF.
2
+ #
3
+ # This software is licensed under the terms of the Apache Licence Version 2.0
4
+ # which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
5
+
6
+ """Tests for TensogramOutput."""
7
+
8
+ from datetime import datetime
9
+ from datetime import timedelta
10
+ from types import SimpleNamespace
11
+
12
+ import numpy as np
13
+ import pytest
14
+
15
+ from tensogram_anemoi.output import TensogramOutput
16
+
17
+ tensogram = pytest.importorskip("tensogram", reason="tensogram not installed")
18
+
19
+ # ---------------------------------------------------------------------------
20
+ # Helpers
21
+ # ---------------------------------------------------------------------------
22
+
23
# Default number of grid points in the synthetic states built by these tests.
N_GRID = 100
# Default field names used by _make_context() and _make_state().
FIELD_NAMES = ["2t", "10u", "10v"]
25
+
26
+
27
+ def _make_variable(param=None, levtype=None, level=None):
28
+ grib = {"param": param} if param is not None else {}
29
+ if levtype is not None:
30
+ grib["levtype"] = levtype
31
+ if level is not None:
32
+ grib["level"] = level
33
+ is_pl = levtype is not None and level is not None
34
+ return SimpleNamespace(
35
+ grib_keys=grib,
36
+ is_computed_forcing=False,
37
+ is_pressure_level=is_pl,
38
+ param=param,
39
+ level=level,
40
+ )
41
+
42
+
43
def _make_context(field_names=FIELD_NAMES):
    """Return a minimal fake inference context.

    The checkpoint exposes one param-only typed variable per entry of
    *field_names*; the context's own ``typed_variables`` mapping is empty.
    """
    variables_by_name = {}
    for field_name in field_names:
        variables_by_name[field_name] = _make_variable(param=field_name)

    return SimpleNamespace(
        checkpoint=SimpleNamespace(typed_variables=variables_by_name),
        reference_date=datetime(2024, 1, 1),
        write_initial_state=False,
        output_frequency=None,
        typed_variables={},
    )
57
+
58
+
59
def _make_state(step_hours=1, field_names=FIELD_NAMES, n_grid=N_GRID, seed=0):
    """Return a synthetic forecast state valid *step_hours* after 2024-01-01 00:00.

    Field values are reproducible float32 random numbers drawn from a
    generator seeded with *seed*, one draw per field in *field_names* order.
    """
    lead_time = timedelta(hours=step_hours)
    generator = np.random.default_rng(seed)

    fields = {}
    for field_name in field_names:
        fields[field_name] = generator.random(n_grid).astype(np.float32)

    return {
        "date": datetime(2024, 1, 1) + lead_time,
        "step": lead_time,
        "latitudes": np.linspace(-90, 90, n_grid, dtype=np.float64),
        "longitudes": np.linspace(0, 360, n_grid, dtype=np.float64),
        "fields": fields,
    }
69
+
70
+
71
+ # ---------------------------------------------------------------------------
72
+ # Tests
73
+ # ---------------------------------------------------------------------------
74
+
75
+
76
def test_write_step_before_open_raises(tmp_path):
    """Calling write_step() on an output that was never opened is a RuntimeError."""
    target = tmp_path / "never.tgm"
    unopened = TensogramOutput(_make_context(), str(target))
    state = _make_state()
    with pytest.raises(RuntimeError, match="open"):
        unopened.write_step(state)
82
+
83
+
84
def test_write_initial_state_reduces_multistep(tmp_path):
    """The state handed to write_step() by write_initial_state() keeps only the last step."""
    from unittest.mock import patch

    target = tmp_path / "init.tgm"
    output = TensogramOutput(_make_context(["2t"]), str(target), write_initial_state=True)

    # Three time steps of the same field, stacked along the first axis.
    history = np.random.default_rng(1).random((3, N_GRID)).astype(np.float32)
    state = {
        "date": datetime(2024, 1, 1),
        "step": timedelta(0),
        "latitudes": np.linspace(-90, 90, N_GRID, dtype=np.float64),
        "longitudes": np.linspace(0, 360, N_GRID, dtype=np.float64),
        "fields": {"2t": history},
    }

    # Record every state that reaches write_step instead of encoding it.
    captured = []

    output.open(state)
    with patch.object(output, "write_step", side_effect=captured.append):
        output.write_initial_state(state)
    output.close()

    assert captured, "write_step was not called by write_initial_state"
    np.testing.assert_array_equal(captured[0]["fields"]["2t"], history[-1])
115
+
116
+
117
def test_write_and_read_local(tmp_path):
    """Three steps written to a local .tgm file decode back to the input data and metadata."""
    target = tmp_path / "forecast.tgm"
    output = TensogramOutput(_make_context(), str(target))

    states = [_make_state(hour) for hour in (1, 2, 3)]
    output.open(states[0])
    for state in states:
        output.write_step(state)
    output.close()

    assert target.exists()

    tgm_file = tensogram.TensogramFile.open(str(target))
    assert len(tgm_file) == 3

    for index, (meta, objects) in enumerate(tgm_file):
        expected = states[index]

        # Two coordinate objects plus one object per field.
        assert len(objects) == 5

        lat_entry = meta.base[0]
        lon_entry = meta.base[1]
        assert lat_entry["name"] == "grid_latitude"
        assert lon_entry["name"] == "grid_longitude"
        assert lat_entry["anemoi"]["variable"] == "latitude"
        assert lon_entry["anemoi"]["variable"] == "longitude"

        lat_desc, lat_arr = objects[0]
        lon_desc, lon_arr = objects[1]
        assert lat_desc.dtype == "float64"
        assert lon_desc.dtype == "float64"
        np.testing.assert_allclose(lat_arr, expected["latitudes"])
        np.testing.assert_allclose(lon_arr, expected["longitudes"])

        lead_time = expected["step"]
        run_start = expected["date"] - lead_time
        expected_step = int(lead_time.total_seconds() / 3600)
        expected_date = run_start.strftime("%Y%m%d")
        expected_time = run_start.strftime("%H%M")

        # Field objects follow the two coordinate objects.
        for obj_idx, name in enumerate(FIELD_NAMES, start=2):
            entry = meta.base[obj_idx]
            assert entry["name"] == name
            assert entry["anemoi"]["variable"] == name
            mars = entry["mars"]
            assert mars["param"] == name
            assert mars["step"] == expected_step
            assert mars["date"] == expected_date
            assert mars["time"] == expected_time
            _, field_arr = objects[obj_idx]
            np.testing.assert_allclose(field_arr, expected["fields"][name], rtol=1e-6)
164
+
165
+
166
def test_variable_filter(tmp_path):
    """With variables=['2t'], the message holds the two coordinates and that field only."""
    target = tmp_path / "filtered.tgm"
    output = TensogramOutput(_make_context(), str(target), variables=["2t"])

    state = _make_state()
    output.open(state)
    output.write_step(state)
    output.close()

    meta, objects = tensogram.TensogramFile.open(str(target))[0]

    assert len(objects) == 3
    assert meta.base[2]["anemoi"]["variable"] == "2t"
183
+
184
+
185
def test_simple_packing_encoding(tmp_path):
    """A 16-bit simple_packing round-trip reproduces the input to within 0.002."""
    target = tmp_path / "packed.tgm"
    output = TensogramOutput(
        _make_context(["2t"]),
        str(target),
        encoding="simple_packing",
        bits=16,
        compression="zstd",
    )

    # Values spread over [250, 300).
    values = (np.random.default_rng(42).random(N_GRID) * 50 + 250).astype(np.float32)
    state = {
        "date": datetime(2024, 1, 1, 1),
        "step": timedelta(hours=1),
        "latitudes": np.linspace(-90, 90, N_GRID, dtype=np.float64),
        "longitudes": np.linspace(0, 360, N_GRID, dtype=np.float64),
        "fields": {"2t": values},
    }

    output.open(state)
    output.write_step(state)
    output.close()

    _meta, objects = tensogram.TensogramFile.open(str(target))[0]
    _, decoded = objects[2]

    np.testing.assert_allclose(decoded, values, atol=0.002)
212
+
213
+
214
def test_simple_packing_requires_bits(tmp_path):
    """Constructing with encoding='simple_packing' and no bits fails with ValueError."""
    context = _make_context(["2t"])
    target = str(tmp_path / "out.tgm")
    with pytest.raises(ValueError, match="bits must be set"):
        TensogramOutput(context, target, encoding="simple_packing")
219
+
220
+
221
def test_mars_metadata_forwarded(tmp_path):
    """The variable's grib_keys and the step's date/time/step land in the 'mars' namespace."""
    target = tmp_path / "levs.tgm"
    context = SimpleNamespace(
        checkpoint=SimpleNamespace(
            typed_variables={"t500": _make_variable(param="t", levtype="pl", level=500)},
        ),
        reference_date=datetime(2024, 1, 1),
        write_initial_state=False,
        output_frequency=None,
        typed_variables={},
    )

    output = TensogramOutput(context, str(target))
    state = {
        "date": datetime(2024, 1, 1, 6),
        "step": timedelta(hours=6),
        "latitudes": np.zeros(10, dtype=np.float64),
        "longitudes": np.zeros(10, dtype=np.float64),
        "fields": {"t500": np.ones(10, dtype=np.float32)},
    }
    output.open(state)
    output.write_step(state)
    output.close()

    meta, _ = tensogram.TensogramFile.open(str(target))[0]
    # Index 2 is the first object after the two coordinate objects.
    field_entry = meta.base[2]
    mars = field_entry["mars"]
    expected_mars = {
        "param": "t",
        "levtype": "pl",
        "level": 500,
        "step": 6,
        "date": "20240101",
        "time": "0000",
    }
    for key, expected_value in expected_mars.items():
        assert mars[key] == expected_value
    assert field_entry["anemoi"]["variable"] == "t500"
258
+
259
+
260
def test_remote_write_via_memory_fs():
    """Write to fsspec memory filesystem and read back valid tensogram messages."""
    import fsspec

    url = "memory://test_forecast.tgm"
    context = _make_context()
    output = TensogramOutput(context, url)

    states = [_make_state(h) for h in range(1, 3)]
    output.open(states[0])
    for state in states:
        output.write_step(state)
    output.close()

    fs = fsspec.filesystem("memory")
    # Close the read handle deterministically instead of leaking it.
    with fs.open(url, "rb") as handle:
        raw = handle.read()

    messages = tensogram.scan(raw)
    assert len(messages) == 2

    for offset, length in messages:
        msg_bytes = raw[offset : offset + length]
        _meta, objects = tensogram.decode(msg_bytes)
        # Two coordinate objects plus one object per default field.
        assert len(objects) == 5
284
+
285
+
286
def test_dtype_float64(tmp_path):
    """With dtype='float64' the field descriptor and decoded array are both float64."""
    target = tmp_path / "f64.tgm"
    output = TensogramOutput(_make_context(["2t"]), str(target), dtype="float64")

    state = _make_state(field_names=["2t"])
    output.open(state)
    output.write_step(state)
    output.close()

    _meta, objects = tensogram.TensogramFile.open(str(target))[0]
    field_desc, field_arr = objects[2]
    assert field_desc.dtype == "float64"
    assert field_arr.dtype == np.float64
302
+
303
+
304
def test_close_is_idempotent(tmp_path):
    """A second close() after a successful close() does not raise."""
    target = tmp_path / "idem.tgm"
    output = TensogramOutput(_make_context(), str(target))
    output.open(_make_state())
    output.write_step(_make_state())
    for _ in range(2):
        output.close()
313
+
314
+
315
def test_dim_names_in_metadata(tmp_path):
    """Dim-name hints appear both at message level and on every base entry.

    The message-level ``_extra_`` dict maps the grid size to ``'values'``;
    coordinate entries carry their canonical name and flat field entries
    carry ``['values']`` in their per-object ``dim_names`` list.
    """
    target = tmp_path / "dimnames.tgm"
    output = TensogramOutput(_make_context(), str(target))
    state = _make_state()
    output.open(state)
    output.write_step(state)
    output.close()

    meta, _ = tensogram.TensogramFile.open(str(target))[0]
    hints = meta.extra["dim_names"]
    grid_key = str(N_GRID)
    assert grid_key in hints
    assert hints[grid_key] == "values"

    # Split the base entries into coordinates (keyed by canonical name) and fields.
    coord_entries = {}
    field_entries = []
    for entry in meta.base:
        variable = entry.get("anemoi", {}).get("variable")
        if variable in {"latitude", "longitude"}:
            coord_entries[variable] = entry
        else:
            field_entries.append(entry)

    # Coord entries pin their axis to the canonical coord name.
    assert coord_entries.get("latitude", {}).get("dim_names") == ["latitude"]
    assert coord_entries.get("longitude", {}).get("dim_names") == ["longitude"]

    # Flat field entries use the 'values' dim.
    assert field_entries, "expected at least one field entry"
    for entry in field_entries:
        assert entry.get("dim_names") == ["values"], entry
351
+
352
+
353
def test_stacked_dim_names_in_metadata(tmp_path):
    """Stacking adds a 'level' size hint and ['values', 'level'] on each stacked entry."""
    levels = [500, 850, 1000]
    target = tmp_path / "stacked_dims.tgm"
    output = TensogramOutput(
        _make_pl_context(params=["t"], levels=levels),
        str(target),
        stack_pressure_levels=True,
    )
    state = _make_pl_state(params=["t"], levels=[500, 850, 1000])
    output.open(state)
    output.write_step(state)
    output.close()

    meta, _ = tensogram.TensogramFile.open(str(target))[0]
    hints = meta.extra["dim_names"]
    grid_key = str(N_GRID)
    level_key = str(3)
    assert grid_key in hints
    assert hints[grid_key] == "values"
    assert level_key in hints
    assert hints[level_key] == "level"

    stacked_entries = []
    for entry in meta.base:
        if entry.get("anemoi", {}).get("variable") == "t":
            stacked_entries.append(entry)
    assert stacked_entries, "expected stacked pressure-level entry"
    for entry in stacked_entries:
        assert entry.get("dim_names") == ["values", "level"], entry
378
+
379
+
380
+ # ---------------------------------------------------------------------------
381
+ # Pressure-level stacking tests
382
+ # ---------------------------------------------------------------------------
383
+
384
# Default levels and parameter names used by _make_pl_context() and _make_pl_state().
PL_LEVELS = [500, 850, 1000]
PL_PARAMS = ["t", "u"]
386
+
387
+
388
def _make_pl_context(params=PL_PARAMS, levels=PL_LEVELS, extra_fields=None):
    """Return a fake context with one 'pl' variable per (param, level) pair.

    Names listed in *extra_fields* are added afterwards as param-only
    (non-pressure-level) variables.
    """
    typed_variables = {
        f"{param}{level}": _make_variable(param=param, levtype="pl", level=level)
        for param in params
        for level in levels
    }
    if extra_fields:
        for extra_name in extra_fields:
            typed_variables[extra_name] = _make_variable(param=extra_name)

    return SimpleNamespace(
        checkpoint=SimpleNamespace(typed_variables=typed_variables),
        reference_date=datetime(2024, 1, 1),
        write_initial_state=False,
        output_frequency=None,
        typed_variables={},
    )
406
+
407
+
408
def _make_pl_state(params=PL_PARAMS, levels=PL_LEVELS, extra_fields=None, n_grid=N_GRID, seed=0):
    """Return a one-hour-step state with a random float32 field per (param, level) pair.

    Fields are drawn from a generator seeded with *seed*, first for every
    param/level combination (params outermost) and then for *extra_fields*.
    """
    generator = np.random.default_rng(seed)

    field_names = [f"{param}{level}" for param in params for level in levels]
    if extra_fields:
        field_names.extend(extra_fields)

    fields = {}
    for field_name in field_names:
        fields[field_name] = generator.random(n_grid).astype(np.float32)

    return {
        "date": datetime(2024, 1, 1, 1),
        "step": timedelta(hours=1),
        "latitudes": np.linspace(-90, 90, n_grid, dtype=np.float64),
        "longitudes": np.linspace(0, 360, n_grid, dtype=np.float64),
        "fields": fields,
    }
423
+
424
+
425
def test_stack_object_count(tmp_path):
    """Stacking yields the two coordinate objects plus one object per parameter."""
    target = tmp_path / "stacked.tgm"
    output = TensogramOutput(_make_pl_context(), str(target), stack_pressure_levels=True)

    state = _make_pl_state()
    output.open(state)
    output.write_step(state)
    output.close()

    _meta, objects = tensogram.TensogramFile.open(str(target))[0]

    expected_count = 2 + len(PL_PARAMS)
    assert len(objects) == expected_count
440
+
441
+
442
def test_stack_shape(tmp_path):
    """Every stacked object decodes to (n_grid, n_levels), with the grid axis first."""
    target = tmp_path / "stacked.tgm"
    output = TensogramOutput(_make_pl_context(), str(target), stack_pressure_levels=True)

    state = _make_pl_state()
    output.open(state)
    output.write_step(state)
    output.close()

    _meta, objects = tensogram.TensogramFile.open(str(target))[0]

    # Stacked objects start right after the two coordinate objects.
    first_stacked = 2
    for obj_idx in range(first_stacked, first_stacked + len(PL_PARAMS)):
        _desc, arr = objects[obj_idx]
        assert arr.shape == (N_GRID, len(PL_LEVELS)), (
            f"expected ({N_GRID}, {len(PL_LEVELS)}), got {arr.shape}"
        )
461
+
462
+
463
def test_stack_levels_metadata(tmp_path):
    """Stacked entries carry an ascending ``mars["levelist"]`` and no scalar level.

    Also checks that the ``anemoi`` namespace holds no level keys, that
    ``levtype`` is ``"pl"``, and that the entry's ``name`` equals the MARS
    ``param``.
    """
    target = tmp_path / "stacked.tgm"
    ctx = _make_pl_context()
    writer = TensogramOutput(ctx, str(target), stack_pressure_levels=True)

    step_state = _make_pl_state()
    writer.open(step_state)
    writer.write_step(step_state)
    writer.close()

    reader = tensogram.TensogramFile.open(str(target))
    meta, _ = reader[0]

    start = 2
    stop = start + len(PL_PARAMS)
    for position in range(start, stop):
        entry = meta.base[position]
        anemoi, mars = entry["anemoi"], entry["mars"]

        # Level information must not leak into the anemoi namespace.
        for key in ("levelist", "level"):
            assert key not in anemoi

        # In the mars namespace only the list form is allowed.
        assert "level" not in mars
        assert mars["levelist"] == sorted(PL_LEVELS)
        assert mars["levtype"] == "pl"

        assert "name" in entry
        assert entry["name"] == mars["param"]
488
+
489
+
490
def test_stack_values_round_trip(tmp_path):
    """Decoded stacked columns match the input fields, ordered by ascending level.

    The levels are supplied unsorted (850, 500, 1000) so that the test
    fails if the writer keeps insertion order instead of sorting.
    """
    input_levels = [850, 500, 1000]
    sorted_levels = [500, 850, 1000]

    target = tmp_path / "stacked.tgm"
    ctx = _make_pl_context(params=["t"], levels=list(input_levels))
    writer = TensogramOutput(ctx, str(target), stack_pressure_levels=True)

    rng = np.random.default_rng(7)
    fields = {}
    for lv in input_levels:
        fields[f"t{lv}"] = rng.random(N_GRID).astype(np.float32)

    step_state = {
        "date": datetime(2024, 1, 1, 1),
        "step": timedelta(hours=1),
        "latitudes": np.linspace(-90, 90, N_GRID, dtype=np.float64),
        "longitudes": np.linspace(0, 360, N_GRID, dtype=np.float64),
        "fields": fields,
    }

    writer.open(step_state)
    writer.write_step(step_state)
    writer.close()

    reader = tensogram.TensogramFile.open(str(target))
    meta, objects = reader[0]

    mars = meta.base[2]["mars"]
    assert mars["levelist"] == sorted_levels

    # Column i of the stacked array must hold the field for the i-th sorted level.
    _, arr = objects[2]
    for column, lv in enumerate(sorted_levels):
        np.testing.assert_allclose(arr[:, column], fields[f"t{lv}"], rtol=1e-6)
520
+
521
+
522
def test_stack_non_pl_fields_written_flat(tmp_path):
    """With stacking on, a non-pressure-level field still gets its own object."""
    target = tmp_path / "mixed.tgm"
    ctx = _make_pl_context(params=["t"], levels=[500, 850], extra_fields=["2t"])
    writer = TensogramOutput(ctx, str(target), stack_pressure_levels=True)

    step_state = _make_pl_state(params=["t"], levels=[500, 850], extra_fields=["2t"])
    writer.open(step_state)
    writer.write_step(step_state)
    writer.close()

    reader = tensogram.TensogramFile.open(str(target))
    meta, objects = reader[0]

    # Presumably 2 leading objects + 1 stacked "t" + 1 flat "2t" -- confirm
    # the breakdown against the writer.
    assert len(objects) == 4

    # Collect the anemoi namespaces of every entry describing "2t".
    flat_entries = []
    for position in range(2, len(objects)):
        anemoi = meta.base[position]["anemoi"]
        if anemoi.get("variable") == "2t":
            flat_entries.append(anemoi)

    assert len(flat_entries) == 1
    only_entry = flat_entries[0]
    for key in ("levels", "level"):
        assert key not in only_entry
546
+
547
+
548
def test_no_stack_level_metadata_correct(tmp_path):
    """With stacking off, every PL field keeps scalar ``level`` and ``levtype`` in mars."""
    target = tmp_path / "flat_pl.tgm"
    ctx = _make_pl_context(params=["t"], levels=[500, 850])
    writer = TensogramOutput(ctx, str(target), stack_pressure_levels=False)

    step_state = _make_pl_state(params=["t"], levels=[500, 850])
    writer.open(step_state)
    writer.write_step(step_state)
    writer.close()

    reader = tensogram.TensogramFile.open(str(target))
    meta, objects = reader[0]

    # One object per level for "t", after the two leading objects.
    assert len(objects) == 4

    for position in range(2, 4):
        entry = meta.base[position]
        mars, anemoi = entry["mars"], entry["anemoi"]
        for key in ("level", "levtype"):
            assert key in mars
        assert mars["levtype"] == "pl"
        assert mars["param"] == "t"
        assert "levels" not in anemoi