pycontrails-0.53.0-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.



Files changed (109)
  1. pycontrails/__init__.py +70 -0
  2. pycontrails/_version.py +16 -0
  3. pycontrails/core/__init__.py +30 -0
  4. pycontrails/core/aircraft_performance.py +641 -0
  5. pycontrails/core/airports.py +226 -0
  6. pycontrails/core/cache.py +881 -0
  7. pycontrails/core/coordinates.py +174 -0
  8. pycontrails/core/fleet.py +470 -0
  9. pycontrails/core/flight.py +2312 -0
  10. pycontrails/core/flightplan.py +220 -0
  11. pycontrails/core/fuel.py +140 -0
  12. pycontrails/core/interpolation.py +721 -0
  13. pycontrails/core/met.py +2833 -0
  14. pycontrails/core/met_var.py +307 -0
  15. pycontrails/core/models.py +1181 -0
  16. pycontrails/core/polygon.py +549 -0
  17. pycontrails/core/rgi_cython.cpython-313-x86_64-linux-gnu.so +0 -0
  18. pycontrails/core/vector.py +2191 -0
  19. pycontrails/datalib/__init__.py +12 -0
  20. pycontrails/datalib/_leo_utils/search.py +250 -0
  21. pycontrails/datalib/_leo_utils/static/bq_roi_query.sql +6 -0
  22. pycontrails/datalib/_leo_utils/vis.py +59 -0
  23. pycontrails/datalib/_met_utils/metsource.py +743 -0
  24. pycontrails/datalib/ecmwf/__init__.py +53 -0
  25. pycontrails/datalib/ecmwf/arco_era5.py +527 -0
  26. pycontrails/datalib/ecmwf/common.py +109 -0
  27. pycontrails/datalib/ecmwf/era5.py +538 -0
  28. pycontrails/datalib/ecmwf/era5_model_level.py +482 -0
  29. pycontrails/datalib/ecmwf/hres.py +782 -0
  30. pycontrails/datalib/ecmwf/hres_model_level.py +495 -0
  31. pycontrails/datalib/ecmwf/ifs.py +284 -0
  32. pycontrails/datalib/ecmwf/model_levels.py +79 -0
  33. pycontrails/datalib/ecmwf/static/model_level_dataframe_v20240418.csv +139 -0
  34. pycontrails/datalib/ecmwf/variables.py +256 -0
  35. pycontrails/datalib/gfs/__init__.py +28 -0
  36. pycontrails/datalib/gfs/gfs.py +646 -0
  37. pycontrails/datalib/gfs/variables.py +100 -0
  38. pycontrails/datalib/goes.py +772 -0
  39. pycontrails/datalib/landsat.py +568 -0
  40. pycontrails/datalib/sentinel.py +512 -0
  41. pycontrails/datalib/spire.py +739 -0
  42. pycontrails/ext/bada.py +41 -0
  43. pycontrails/ext/cirium.py +14 -0
  44. pycontrails/ext/empirical_grid.py +140 -0
  45. pycontrails/ext/synthetic_flight.py +426 -0
  46. pycontrails/models/__init__.py +1 -0
  47. pycontrails/models/accf.py +406 -0
  48. pycontrails/models/apcemm/__init__.py +8 -0
  49. pycontrails/models/apcemm/apcemm.py +983 -0
  50. pycontrails/models/apcemm/inputs.py +226 -0
  51. pycontrails/models/apcemm/static/apcemm_yaml_template.yaml +183 -0
  52. pycontrails/models/apcemm/utils.py +437 -0
  53. pycontrails/models/cocip/__init__.py +29 -0
  54. pycontrails/models/cocip/cocip.py +2617 -0
  55. pycontrails/models/cocip/cocip_params.py +299 -0
  56. pycontrails/models/cocip/cocip_uncertainty.py +285 -0
  57. pycontrails/models/cocip/contrail_properties.py +1517 -0
  58. pycontrails/models/cocip/output_formats.py +2261 -0
  59. pycontrails/models/cocip/radiative_forcing.py +1262 -0
  60. pycontrails/models/cocip/radiative_heating.py +520 -0
  61. pycontrails/models/cocip/unterstrasser_wake_vortex.py +403 -0
  62. pycontrails/models/cocip/wake_vortex.py +396 -0
  63. pycontrails/models/cocip/wind_shear.py +120 -0
  64. pycontrails/models/cocipgrid/__init__.py +9 -0
  65. pycontrails/models/cocipgrid/cocip_grid.py +2573 -0
  66. pycontrails/models/cocipgrid/cocip_grid_params.py +138 -0
  67. pycontrails/models/dry_advection.py +486 -0
  68. pycontrails/models/emissions/__init__.py +21 -0
  69. pycontrails/models/emissions/black_carbon.py +594 -0
  70. pycontrails/models/emissions/emissions.py +1353 -0
  71. pycontrails/models/emissions/ffm2.py +336 -0
  72. pycontrails/models/emissions/static/default-engine-uids.csv +239 -0
  73. pycontrails/models/emissions/static/edb-gaseous-v29b-engines.csv +596 -0
  74. pycontrails/models/emissions/static/edb-nvpm-v29b-engines.csv +215 -0
  75. pycontrails/models/humidity_scaling/__init__.py +37 -0
  76. pycontrails/models/humidity_scaling/humidity_scaling.py +1025 -0
  77. pycontrails/models/humidity_scaling/quantiles/era5-model-level-quantiles.pq +0 -0
  78. pycontrails/models/humidity_scaling/quantiles/era5-pressure-level-quantiles.pq +0 -0
  79. pycontrails/models/issr.py +210 -0
  80. pycontrails/models/pcc.py +327 -0
  81. pycontrails/models/pcr.py +154 -0
  82. pycontrails/models/ps_model/__init__.py +17 -0
  83. pycontrails/models/ps_model/ps_aircraft_params.py +376 -0
  84. pycontrails/models/ps_model/ps_grid.py +505 -0
  85. pycontrails/models/ps_model/ps_model.py +1017 -0
  86. pycontrails/models/ps_model/ps_operational_limits.py +540 -0
  87. pycontrails/models/ps_model/static/ps-aircraft-params-20240524.csv +68 -0
  88. pycontrails/models/ps_model/static/ps-synonym-list-20240524.csv +103 -0
  89. pycontrails/models/sac.py +459 -0
  90. pycontrails/models/tau_cirrus.py +168 -0
  91. pycontrails/physics/__init__.py +1 -0
  92. pycontrails/physics/constants.py +116 -0
  93. pycontrails/physics/geo.py +989 -0
  94. pycontrails/physics/jet.py +837 -0
  95. pycontrails/physics/thermo.py +451 -0
  96. pycontrails/physics/units.py +472 -0
  97. pycontrails/py.typed +0 -0
  98. pycontrails/utils/__init__.py +1 -0
  99. pycontrails/utils/dependencies.py +66 -0
  100. pycontrails/utils/iteration.py +13 -0
  101. pycontrails/utils/json.py +188 -0
  102. pycontrails/utils/temp.py +50 -0
  103. pycontrails/utils/types.py +165 -0
  104. pycontrails-0.53.0.dist-info/LICENSE +178 -0
  105. pycontrails-0.53.0.dist-info/METADATA +181 -0
  106. pycontrails-0.53.0.dist-info/NOTICE +43 -0
  107. pycontrails-0.53.0.dist-info/RECORD +109 -0
  108. pycontrails-0.53.0.dist-info/WHEEL +6 -0
  109. pycontrails-0.53.0.dist-info/top_level.txt +3 -0
pycontrails/core/met.py
@@ -0,0 +1,2833 @@
1
+ """Meteorology data models."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import hashlib
6
+ import json
7
+ import logging
8
+ import pathlib
9
+ import typing
10
+ import warnings
11
+ from abc import ABC, abstractmethod
12
+ from collections.abc import (
13
+ Generator,
14
+ Hashable,
15
+ Iterable,
16
+ Iterator,
17
+ Mapping,
18
+ MutableMapping,
19
+ Sequence,
20
+ )
21
+ from contextlib import ExitStack
22
+ from datetime import datetime
23
+ from typing import (
24
+ TYPE_CHECKING,
25
+ Any,
26
+ Generic,
27
+ Literal,
28
+ TypeVar,
29
+ overload,
30
+ )
31
+
32
+ import numpy as np
33
+ import numpy.typing as npt
34
+ import pandas as pd
35
+ import xarray as xr
36
+ from overrides import overrides
37
+
38
+ from pycontrails.core import interpolation
39
+ from pycontrails.core import vector as vector_module
40
+ from pycontrails.core.cache import CacheStore, DiskCacheStore
41
+ from pycontrails.core.met_var import AirPressure, Altitude, MetVariable
42
+ from pycontrails.physics import units
43
+ from pycontrails.utils import dependencies
44
+ from pycontrails.utils import temp as temp_module
45
+
46
+ logger = logging.getLogger(__name__)
47
+
48
+ # optional imports
49
+ if TYPE_CHECKING:
50
+ import open3d as o3d
51
+
52
+ XArrayType = TypeVar("XArrayType", xr.Dataset, xr.DataArray)
53
+ MetDataType = TypeVar("MetDataType", "MetDataset", "MetDataArray")
54
+ DatasetType = TypeVar("DatasetType", xr.Dataset, "MetDataset")
55
+
56
+ COORD_DTYPE = np.float64
57
+
58
+
59
+ class MetBase(ABC, Generic[XArrayType]):
60
+ """Abstract class for building Meteorology Data handling classes.
61
+
62
+ All functionality here should be generic enough to work on both
63
+ xr.DataArray and xr.Dataset.
64
+ """
65
+
66
+ #: DataArray or Dataset
67
+ data: XArrayType
68
+
69
+ #: Cache datastore to use for :meth:`save` or :meth:`load`
70
+ cachestore: CacheStore | None
71
+
72
+ #: Default dimension order for DataArray or Dataset (x, y, z, t)
73
+ dim_order: tuple[Hashable, Hashable, Hashable, Hashable] = (
74
+ "longitude",
75
+ "latitude",
76
+ "level",
77
+ "time",
78
+ )
79
+
80
+ def __repr__(self) -> str:
81
+ data = getattr(self, "data", None)
82
+ return (
83
+ f"{self.__class__.__name__} with data:\n\n{data.__repr__() if data is not None else ''}"
84
+ )
85
+
86
+ def _repr_html_(self) -> str:
87
+ try:
88
+ return f"<b>{type(self).__name__}</b> with data:<br/ ><br/> {self.data._repr_html_()}"
89
+ except AttributeError:
90
+ return f"<b>{type(self).__name__}</b> without data"
91
+
92
+ def _validate_dim_contains_coords(self) -> None:
93
+ """Check that data contains four temporal-spatial coordinates.
94
+
95
+ Raises
96
+ ------
97
+ ValueError
98
+ If data does not contain all four coordinates (longitude, latitude, level, time).
99
+ """
100
+ for dim in self.dim_order:
101
+ if dim not in self.data.dims:
102
+ if dim == "level":
103
+ msg = (
104
+ f"Meteorology data must contain dimension '{dim}'. "
105
+ "For single level data, set 'level' coordinate to constant -1 "
106
+ "using `ds = ds.expand_dims({'level': [-1]})`"
107
+ )
108
+ else:
109
+ msg = f"Meteorology data must contain dimension '{dim}'."
110
+ raise ValueError(msg)
111
+
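For illustration, a minimal sketch of preparing single-level data before construction, following the hint in the error message above; it assumes an existing xr.Dataset ds that lacks a level dimension:

    from pycontrails import MetDataset

    # pycontrails convention: single-level (surface) data uses level = -1
    ds = ds.expand_dims({"level": [-1]})
    met = MetDataset(ds)
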
112
+ def _validate_longitude(self) -> None:
113
+ """Check longitude bounds.
114
+
115
+ Assumes ``longitude`` dimension is already sorted.
116
+
117
+ Raises
118
+ ------
119
+ ValueError
120
+ If longitude values are not contained in the interval [-180, 180].
121
+ """
122
+ longitude = self.indexes["longitude"].to_numpy()
123
+ if longitude.dtype != COORD_DTYPE:
124
+ raise ValueError(
125
+ "Longitude values must be of type float64. "
126
+ "Initiate with 'copy=True' to convert to float64. "
127
+ "Initiate with 'validate=False' to skip validation."
128
+ )
129
+
130
+ if self.is_wrapped:
131
+ # Relax verification if the longitude has already been processed and wrapped
132
+ if longitude[-1] > 360.0:
133
+ raise ValueError(
134
+ "Longitude contains values > 360. Shift to WGS84 with "
135
+ "'data.assign_coords(longitude=(((data.longitude + 180) % 360) - 180))'"
136
+ )
137
+ if longitude[0] < -360.0:
138
+ raise ValueError(
139
+ "Longitude contains values < -360. Shift to WGS84 with "
140
+ "'data.assign_coords(longitude=(((data.longitude + 180) % 360) - 180))'"
141
+ )
142
+ return
143
+
144
+ # Strict!
145
+ if longitude[-1] > 180.0:
146
+ raise ValueError(
147
+ "Longitude contains values > 180. Shift to WGS84 with "
148
+ "'data.assign_coords(longitude=(((data.longitude + 180) % 360) - 180))'"
149
+ )
150
+ if longitude[0] < -180.0:
151
+ raise ValueError(
152
+ "Longitude contains values < -180. Shift to WGS84 with "
153
+ "'data.assign_coords(longitude=(((data.longitude + 180) % 360) - 180))'"
154
+ )
155
+
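A minimal sketch of the shift quoted in these error messages, assuming ds is an xr.Dataset with longitudes on [0, 360):

    # shift longitudes to the WGS84 convention [-180, 180) and restore ascending order
    ds = ds.assign_coords(longitude=(((ds.longitude + 180) % 360) - 180))
    ds = ds.sortby("longitude")
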
156
+ def _validate_latitude(self) -> None:
157
+ """Check latitude bounds.
158
+
159
+ Assumes ``latitude`` dimension is already sorted.
160
+
161
+ Raises
162
+ ------
163
+ ValueError
164
+ If latitude values are not contained in the interval [-90, 90].
165
+ """
166
+ latitude = self.indexes["latitude"].to_numpy()
167
+ if latitude.dtype != COORD_DTYPE:
168
+ raise ValueError(
169
+ "Latitude values must be of type float64. "
170
+ "Initiate with 'copy=True' to convert to float64. "
171
+ "Initiate with 'validate=False' to skip validation."
172
+ )
173
+
174
+ if latitude[0] < -90.0:
175
+ raise ValueError(
176
+ "Latitude contains values < -90 . "
177
+ "Latitude values must be contained in the interval [-90, 90]."
178
+ )
179
+ if latitude[-1] > 90.0:
180
+ raise ValueError(
181
+ "Latitude contains values > 90 . "
182
+ "Latitude values must be contained in the interval [-90, 90]."
183
+ )
184
+
185
+ def _validate_sorting(self) -> None:
186
+ """Check that all coordinates are sorted.
187
+
188
+ Raises
189
+ ------
190
+ ValueError
191
+ If one of the coordinates is not sorted.
192
+ """
193
+ indexes = self.indexes
194
+ if not np.all(np.diff(indexes["time"]) > np.timedelta64(0, "ns")):
195
+ raise ValueError("Coordinate `time` not sorted. Initiate with `copy=True`.")
196
+ for coord in self.dim_order[:3]: # exclude time, the 4th dimension
197
+ if not np.all(np.diff(indexes[coord]) > 0.0):
198
+ raise ValueError(f"Coordinate '{coord}' not sorted. Initiate with 'copy=True'.")
199
+
200
+ def _validate_transpose(self) -> None:
201
+ """Check that data is transposed according to :attr:`dim_order`."""
202
+
203
+ def _check_da(da: xr.DataArray, key: Hashable | None = None) -> None:
204
+ if da.dims != self.dim_order:
205
+ if key is not None:
206
+ msg = (
207
+ f"Data dimension not transposed on variable '{key}'. Initiate with"
208
+ " 'copy=True'."
209
+ )
210
+ else:
211
+ msg = "Data dimension not transposed. Initiate with 'copy=True'."
212
+ raise ValueError(msg)
213
+
214
+ data = self.data
215
+ if isinstance(data, xr.DataArray):
216
+ _check_da(data)
217
+ return
218
+
219
+ for key, da in self.data.items():
220
+ _check_da(da, key)
221
+
222
+ def _validate_dims(self) -> None:
223
+ """Apply all validators."""
224
+ self._validate_dim_contains_coords()
225
+
226
+ # Apply this one first: validate_longitude and validate_latitude assume sorted
227
+ self._validate_sorting()
228
+ self._validate_longitude()
229
+ self._validate_latitude()
230
+ self._validate_transpose()
231
+
232
+ def _preprocess_dims(self, wrap_longitude: bool) -> None:
233
+ """Confirm DataArray or Dataset include required dimension in a consistent format.
234
+
235
+ Expects DataArray or Dataset to contain dimensions ``latitude``, ``longitude``, ``time``,
236
+ and ``level`` (in hPa/mbar).
237
+ Adds additional coordinate variables ``air_pressure`` and ``altitude``
238
+ mapped to "level" dimension if "level" > 0.
239
+
240
+ Set ``level`` to -1 to signify single level.
241
+
242
+ .. versionchanged:: 0.40.0
243
+
244
+ All coordinate data (longitude, latitude, level) are promoted to ``float64``.
245
+ Auxiliary coordinates (altitude and air_pressure) are now cast to the same
246
+ dtype as the underlying grid data.
247
+
248
+
249
+ Parameters
250
+ ----------
251
+ wrap_longitude : bool
252
+ If True, ensure longitude values cover the interval ``[-180, 180]``.
253
+
254
+ Raises
255
+ ------
256
+ ValueError
257
+ Raises if required dimension names are not found
258
+ """
259
+ self._validate_dim_contains_coords()
260
+
261
+ # Ensure spatial coordinates all have dtype COORD_DTYPE
262
+ indexes = self.indexes
263
+ for coord in ("longitude", "latitude", "level"):
264
+ arr = indexes[coord].to_numpy()
265
+ if arr.dtype != COORD_DTYPE:
266
+ self.data[coord] = arr.astype(COORD_DTYPE)
267
+
268
+ # Ensure time is np.datetime64[ns]
269
+ self.data["time"] = self.data["time"].astype("datetime64[ns]", copy=False)
270
+
271
+ # sortby to ensure each coordinate has ascending order
272
+ self.data = self.data.sortby(list(self.dim_order), ascending=True)
273
+
274
+ if not self.is_wrapped:
275
+ # Ensure longitude is contained in interval [-180, 180)
276
+ # If longitude has value at 180, we might not want to shift it?
277
+ lon = self.indexes["longitude"].to_numpy()
278
+
279
+ # This longitude shifting can give rise to precision errors with float32
280
+ # Only shift if necessary
281
+ if np.any(lon >= 180.0) or np.any(lon < -180.0):
282
+ self.data = shift_longitude(self.data)
283
+ else:
284
+ self.data = self.data.sortby("longitude", ascending=True)
285
+
286
+ # wrap longitude, if requested
287
+ if wrap_longitude:
288
+ self.data = _wrap_longitude(self.data)
289
+
290
+ self._validate_longitude()
291
+ self._validate_latitude()
292
+
293
+ # transpose to have ordering (x, y, z, t, ...)
294
+ dim_order = [*self.dim_order, *(d for d in self.data.dims if d not in self.dim_order)]
295
+ self.data = self.data.transpose(*dim_order)
296
+
297
+ # single level data
298
+ if self.is_single_level:
299
+ # add level attributes to reflect surface level
300
+ level_attrs = self.data["level"].attrs
301
+ if not level_attrs:
302
+ level_attrs.update(units="", long_name="Single Level")
303
+ return
304
+
305
+ self.data = _add_vertical_coords(self.data)
306
+
307
+ @property
308
+ def hash(self) -> str:
309
+ """Generate a unique hash for this met instance.
310
+
311
+ Note this is not as robust as it could be since the `repr`
312
+ output may be truncated.
313
+
314
+ Returns
315
+ -------
316
+ str
317
+ Unique hash for met instance (sha1)
318
+ """
319
+ _hash = repr(self.data)
320
+ return hashlib.sha1(bytes(_hash, "utf-8")).hexdigest()
321
+
322
+ @property
323
+ @abstractmethod
324
+ def size(self) -> int:
325
+ """Return the size of (each) array in underlying :attr:`data`.
326
+
327
+ Returns
328
+ -------
329
+ int
330
+ Total number of grid points in underlying data
331
+ """
332
+
333
+ @property
334
+ @abstractmethod
335
+ def shape(self) -> tuple[int, int, int, int]:
336
+ """Return the shape of the dimensions.
337
+
338
+ Returns
339
+ -------
340
+ tuple[int, int, int, int]
341
+ Shape of underlying data
342
+ """
343
+
344
+ @property
345
+ def coords(self) -> dict[str, np.ndarray]:
346
+ """Get coordinates of underlying :attr:`data` coordinates.
347
+
348
+ Only returns dimension coordinates (longitude, latitude, level, time).
349
+
350
+ See:
351
+ http://xarray.pydata.org/en/stable/user-guide/data-structures.html#coordinates
352
+
353
+ Returns
354
+ -------
355
+ dict[str, np.ndarray]
356
+ Dictionary of coordinates
357
+ """
358
+ variables = self.indexes
359
+ return {
360
+ "longitude": variables["longitude"].to_numpy(),
361
+ "latitude": variables["latitude"].to_numpy(),
362
+ "level": variables["level"].to_numpy(),
363
+ "time": variables["time"].to_numpy(),
364
+ }
365
+
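A short sketch of using coords to build a grid-shaped DataArray, assuming met is an existing MetDataset; the constant value is illustrative:

    import numpy as np
    import xarray as xr

    # DataArray aligned with the met grid, filled with a constant
    da = xr.DataArray(np.full(met.shape, 234.5), coords=met.coords)
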
366
+ @property
367
+ def variables(self) -> dict[Hashable, pd.Index]:
368
+ """See :attr:`indexes`."""
369
+ warnings.warn(
370
+ "The 'variables' property is deprecated and will be removed in a future release. "
371
+ "Use 'indexes' instead.",
372
+ DeprecationWarning,
373
+ )
374
+ return self.indexes
375
+
376
+ @property
377
+ def indexes(self) -> dict[Hashable, pd.Index]:
378
+ """Low level access to underlying :attr:`data` indexes.
379
+
380
+ This is typically faster for accessing coordinate indexes.
381
+
382
+ .. versionadded:: 0.25.2
383
+
384
+ Returns
385
+ -------
386
+ dict[Hashable, pd.Index]
387
+ Dictionary of indexes.
388
+
389
+ Examples
390
+ --------
391
+ >>> from pycontrails.datalib.ecmwf import ERA5
392
+ >>> times = (datetime(2022, 3, 1, 12), datetime(2022, 3, 1, 13))
393
+ >>> variables = "air_temperature", "specific_humidity"
394
+ >>> levels = [200, 300]
395
+ >>> era5 = ERA5(times, variables, levels)
396
+ >>> mds = era5.open_metdataset()
397
+ >>> mds.indexes["level"].to_numpy()
398
+ array([200., 300.])
399
+
400
+ >>> mda = mds["air_temperature"]
401
+ >>> mda.indexes["level"].to_numpy()
402
+ array([200., 300.])
403
+ """
404
+ return {k: v.index for k, v in self.data._indexes.items()} # type: ignore[attr-defined]
405
+
406
+ @property
407
+ def is_wrapped(self) -> bool:
408
+ """Check if the longitude dimension covers the closed interval ``[-180, 180]``.
409
+
410
+ Assumes the longitude dimension is sorted (this is established by the
411
+ :class:`MetDataset` or :class:`MetDataArray` constructor).
412
+
413
+ .. versionchanged:: 0.26.0
414
+
415
+ The previous implementation checked for the minimum and maximum longitude
416
+ dimension values to be duplicated. The current implementation only checks
417
+ that the interval ``[-180, 180]`` is covered by the longitude dimension. The
418
+ :func:`pycontrails.physics.geo.advect_longitude` is designed for compatibility
419
+ with this convention.
420
+
421
+ Returns
422
+ -------
423
+ bool
424
+ True if longitude coordinates cover ``[-180, 180]``
425
+
426
+ See Also
427
+ --------
428
+ :func:`pycontrails.physics.geo.advect_longitude`
429
+ """
430
+ longitude = self.indexes["longitude"].to_numpy()
431
+ return _is_wrapped(longitude)
432
+
433
+ @property
434
+ def is_single_level(self) -> bool:
435
+ """Check if instance contains "single level" or "surface level" data.
436
+
437
+ This method checks if the ``level`` dimension contains a single value equal
438
+ to -1, the pycontrails convention for surface-only data.
439
+
440
+ Returns
441
+ -------
442
+ bool
443
+ If instance contains single level data.
444
+ """
445
+ level = self.indexes["level"].to_numpy()
446
+ return len(level) == 1 and level[0] == -1
447
+
448
+ @abstractmethod
449
+ def broadcast_coords(self, name: str) -> xr.DataArray:
450
+ """Broadcast coordinates along other dimensions.
451
+
452
+ Parameters
453
+ ----------
454
+ name : str
455
+ Coordinate/dimension name to broadcast.
456
+ Can be a dimension or non-dimension coordinates.
457
+
458
+ Returns
459
+ -------
460
+ xr.DataArray
461
+ DataArray of the coordinate broadcasted along all other dimensions.
462
+ The DataArray will have the same shape as the gridded data.
463
+ """
464
+
465
+ def _save(self, dataset: xr.Dataset, **kwargs: Any) -> list[str]:
466
+ """Save dataset to netcdf files named for the met hash and hour.
467
+
468
+ Does not yet save in parallel.
469
+
470
+ .. versionchanged:: 0.34.1
471
+
472
+ If :attr:`cachestore` is None, this method assigns it
473
+ to a new :class:`DiskCacheStore`.
474
+
475
+ Parameters
476
+ ----------
477
+ dataset : xr.Dataset
478
+ Dataset to save
479
+ **kwargs
480
+ Keyword args passed directly on to :func:`xarray.save_mfdataset`
481
+
482
+ Returns
483
+ -------
484
+ list[str]
485
+ List of filenames saved
486
+ """
487
+ self.cachestore = self.cachestore or DiskCacheStore()
488
+
489
+ # group by hour and save one dataset for each hour to temp file
490
+ times, datasets = zip(*dataset.groupby("time", squeeze=False), strict=True)
491
+
492
+ # Open ExitStack to control temp_file context manager
493
+ with ExitStack() as stack:
494
+ xarray_temp_filenames = [stack.enter_context(temp_module.temp_file()) for _ in times]
495
+ xr.save_mfdataset(datasets, xarray_temp_filenames, **kwargs)
496
+
497
+ # set filenames by hash
498
+ filenames = [f"{self.hash}-{t_idx}.nc" for t_idx, _ in enumerate(times)]
499
+
500
+ # put each hourly file into cache
501
+ self.cachestore.put_multiple(xarray_temp_filenames, filenames)
502
+
503
+ return filenames
504
+
505
+ def __len__(self) -> int:
506
+ return self.data.__len__()
507
+
508
+ @property
509
+ def attrs(self) -> dict[Hashable, Any]:
510
+ """Pass through to :attr:`self.data.attrs`."""
511
+ return self.data.attrs
512
+
513
+ @abstractmethod
514
+ def downselect(self, bbox: tuple[float, ...]) -> MetBase:
515
+ """Downselect met data within spatial bounding box.
516
+
517
+ Parameters
518
+ ----------
519
+ bbox : tuple[float, ...]
520
+ Tuple of coordinates defining a spatial bounding box in WGS84 coordinates.
521
+ For 2D queries, list is [west, south, east, north].
522
+ For 3D queries, list is [west, south, min-level, east, north, max-level]
523
+ with level defined in [:math:`hPa`].
524
+
525
+
526
+ Returns
527
+ -------
528
+ MetBase
529
+ Return downselected data
530
+ """
531
+
532
+ @property
533
+ def is_zarr(self) -> bool:
534
+ """Check if underlying :attr:`data` is sourced from a Zarr group.
535
+
536
+ Implementation is very brittle, and may break as external libraries change.
537
+
538
+ Some ``dask`` intermediate artifact is cached when this is called. Typically,
539
+ subsequent calls to this method are much faster than the initial call.
540
+
541
+ .. versionadded:: 0.26.0
542
+
543
+ Returns
544
+ -------
545
+ bool
546
+ If ``data`` is based on a Zarr group.
547
+ """
548
+ return _is_zarr(self.data)
549
+
550
+ def downselect_met(
551
+ self,
552
+ met: MetDataType,
553
+ *,
554
+ longitude_buffer: tuple[float, float] = (0.0, 0.0),
555
+ latitude_buffer: tuple[float, float] = (0.0, 0.0),
556
+ level_buffer: tuple[float, float] = (0.0, 0.0),
557
+ time_buffer: tuple[np.timedelta64, np.timedelta64] = (
558
+ np.timedelta64(0, "h"),
559
+ np.timedelta64(0, "h"),
560
+ ),
561
+ copy: bool = True,
562
+ ) -> MetDataType:
563
+ """Downselect ``met`` to encompass a spatiotemporal region of the data.
564
+
565
+ .. warning::
566
+
567
+ This method is analogous to :meth:`GeoVectorDataset.downselect_met`.
568
+ It does not change the instance data, but instead operates on the
569
+ ``met`` input. This method is different from :meth:`downselect` which
570
+ operates on the instance data.
571
+
572
+ Parameters
573
+ ----------
574
+ met : MetDataset | MetDataArray
575
+ MetDataset or MetDataArray to downselect.
576
+ longitude_buffer : tuple[float, float], optional
577
+ Extend longitude domain past by ``longitude_buffer[0]`` on the low side
578
+ and ``longitude_buffer[1]`` on the high side.
579
+ Units must be the same as class coordinates.
580
+ Defaults to ``(0, 0)`` degrees.
581
+ latitude_buffer : tuple[float, float], optional
582
+ Extend latitude domain past by ``latitude_buffer[0]`` on the low side
583
+ and ``latitude_buffer[1]`` on the high side.
584
+ Units must be the same as class coordinates.
585
+ Defaults to ``(0, 0)`` degrees.
586
+ level_buffer : tuple[float, float], optional
587
+ Extend level domain past by ``level_buffer[0]`` on the low side
588
+ and ``level_buffer[1]`` on the high side.
589
+ Units must be the same as class coordinates.
590
+ Defaults to ``(0, 0)`` [:math:`hPa`].
591
+ time_buffer : tuple[np.timedelta64, np.timedelta64], optional
592
+ Extend time domain past by ``time_buffer[0]`` on the low side
593
+ and ``time_buffer[1]`` on the high side.
594
+ Units must be the same as class coordinates.
595
+ Defaults to ``(np.timedelta64(0, "h"), np.timedelta64(0, "h"))``.
596
+ copy : bool
597
+ If returned object is a copy or view of the original. True by default.
598
+
599
+ Returns
600
+ -------
601
+ MetDataset | MetDataArray
602
+ Copy of downselected MetDataset or MetDataArray.
603
+ """
604
+ indexes = self.indexes
605
+ lon = indexes["longitude"].to_numpy()
606
+ lat = indexes["latitude"].to_numpy()
607
+ level = indexes["level"].to_numpy()
608
+ time = indexes["time"].to_numpy()
609
+
610
+ vector = vector_module.GeoVectorDataset(
611
+ longitude=[lon.min(), lon.max()],
612
+ latitude=[lat.min(), lat.max()],
613
+ level=[level.min(), level.max()],
614
+ time=[time.min(), time.max()],
615
+ )
616
+
617
+ return vector.downselect_met(
618
+ met,
619
+ longitude_buffer=longitude_buffer,
620
+ latitude_buffer=latitude_buffer,
621
+ level_buffer=level_buffer,
622
+ time_buffer=time_buffer,
623
+ copy=copy,
624
+ )
625
+
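A minimal usage sketch of downselect_met, assuming met_small and met_large are existing MetDataset instances; the buffer values are illustrative only:

    import numpy as np

    # trim met_large to the spatiotemporal extent of met_small, padded by small buffers
    met_trimmed = met_small.downselect_met(
        met_large,
        longitude_buffer=(1.0, 1.0),  # degrees
        latitude_buffer=(1.0, 1.0),  # degrees
        level_buffer=(20.0, 20.0),  # hPa
        time_buffer=(np.timedelta64(1, "h"), np.timedelta64(1, "h")),
    )
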
626
+
627
+ class MetDataset(MetBase):
628
+ """Meteorological dataset with multiple variables.
629
+
630
+ Composition around xr.Dataset to enforce certain
631
+ variables and dimensions for internal usage.
632
+
633
+ Parameters
634
+ ----------
635
+ data : xr.Dataset
636
+ :class:`xarray.Dataset` containing meteorological variables and coordinates
637
+ cachestore : :class:`CacheStore`, optional
638
+ Cache datastore for staging intermediates with :meth:`save`.
639
+ Defaults to None.
640
+ wrap_longitude : bool, optional
641
+ Wrap data along the longitude dimension. If True, duplicate and shift longitude
642
+ values (ie, -180 -> 180) to ensure that the longitude dimension covers the entire
643
+ interval ``[-180, 180]``. Defaults to False.
644
+ copy : bool, optional
645
+ Copy data on construction. Defaults to True.
646
+ attrs : dict[str, Any], optional
647
+ Attributes to add to :attr:`data.attrs`. Defaults to None.
648
+ Generally, pycontrails :class:`Models` may use the following attributes:
649
+
650
+ - ``provider``: Name of the data provider (e.g. "ECMWF").
651
+ - ``dataset``: Name of the dataset (e.g. "ERA5").
652
+ - ``product``: Name of the product type (e.g. "reanalysis").
653
+
654
+ **attrs_kwargs : Any
655
+ Keyword arguments to add to :attr:`data.attrs`. Defaults to None.
656
+
657
+ Examples
658
+ --------
659
+ >>> import numpy as np
660
+ >>> import pandas as pd
661
+ >>> import xarray as xr
662
+ >>> from pycontrails.datalib.ecmwf import ERA5
663
+
664
+ >>> time = ("2022-03-01T00", "2022-03-01T02")
665
+ >>> variables = ["air_temperature", "specific_humidity"]
666
+ >>> pressure_levels = [200, 250, 300]
667
+ >>> era5 = ERA5(time, variables, pressure_levels)
668
+
669
+ >>> # Open directly as `MetDataset`
670
+ >>> met = era5.open_metdataset()
671
+ >>> # Use `data` attribute to access `xarray` object
672
+ >>> assert isinstance(met.data, xr.Dataset)
673
+
674
+ >>> # Alternatively, open with `xarray` and cast to `MetDataset`
675
+ >>> ds = xr.open_mfdataset(era5._cachepaths)
676
+ >>> met = MetDataset(ds)
677
+
678
+ >>> # Access sub-`DataArrays`
679
+ >>> mda = met["t"] # `MetDataArray` instance, needed for interpolation operations
680
+ >>> da = mda.data # Underlying `xarray` object
681
+
682
+ >>> # Check out a few values
683
+ >>> da[5:8, 5:8, 1, 1].values
684
+ array([[224.08959005, 224.41374427, 224.75945349],
685
+ [224.09456429, 224.42037658, 224.76525676],
686
+ [224.10036756, 224.42617985, 224.77106004]])
687
+
688
+ >>> # Mean temperature over entire array
689
+ >>> da.mean().load().item()
690
+ 223.5083
691
+ """
692
+
693
+ data: xr.Dataset
694
+
695
+ def __init__(
696
+ self,
697
+ data: xr.Dataset,
698
+ cachestore: CacheStore | None = None,
699
+ wrap_longitude: bool = False,
700
+ copy: bool = True,
701
+ attrs: dict[str, Any] | None = None,
702
+ **attrs_kwargs: Any,
703
+ ) -> None:
704
+ self.cachestore = cachestore
705
+
706
+ data.attrs.update(attrs or {}, **attrs_kwargs)
707
+
708
+ # ensure input is an xarray Dataset
709
+ if not isinstance(data, xr.Dataset):
710
+ raise TypeError("Input 'data' must be an xarray Dataset")
711
+
712
+ # copy Dataset into data
713
+ if copy:
714
+ self.data = data.copy()
715
+ self._preprocess_dims(wrap_longitude)
716
+
717
+ else:
718
+ if wrap_longitude:
719
+ raise ValueError("Set 'copy=True' when using 'wrap_longitude=True'.")
720
+ self.data = data
721
+ self._validate_dims()
722
+ if not self.is_single_level:
723
+ self.data = _add_vertical_coords(self.data)
724
+
725
+ def __getitem__(self, key: Hashable) -> MetDataArray:
726
+ """Return DataArray of variable ``key`` cast to a :class:`MetDataArray` object.
727
+
728
+ Parameters
729
+ ----------
730
+ key : Hashable
731
+ Variable name
732
+
733
+ Returns
734
+ -------
735
+ MetDataArray
736
+ MetDataArray instance associated to :attr:`data` variable `key`
737
+
738
+ Raises
739
+ ------
740
+ KeyError
741
+ If ``key`` not found in :attr:`data`
742
+ """
743
+ try:
744
+ da = self.data[key]
745
+ except KeyError as e:
746
+ raise KeyError(
747
+ f"Variable {key} not found. Available variables: {', '.join(self.data.data_vars)}. "
748
+ "To get items (e.g. `time` or `level`) from underlying `xr.Dataset` object, "
749
+ "use the `data` attribute."
750
+ ) from e
751
+ return MetDataArray(da, copy=False, validate=False)
752
+
753
+ def get(self, key: str, default_value: Any = None) -> Any:
754
+ """Shortcut to :meth:`data.get(k, v)` method.
755
+
756
+ Parameters
757
+ ----------
758
+ key : str
759
+ Key to get from :attr:`data`
760
+ default_value : Any, optional
761
+ Return `default_value` if `key` not in :attr:`data`, by default `None`
762
+
763
+ Returns
764
+ -------
765
+ Any
766
+ Values returned from :attr:`data.get(key, default_value)`
767
+ """
768
+ return self.data.get(key, default_value)
769
+
770
+ def __setitem__(
771
+ self,
772
+ key: Hashable | list[Hashable] | Mapping,
773
+ value: Any,
774
+ ) -> None:
775
+ """Shortcut to set data variable on :attr:`data`.
776
+
777
+ Warns if ``key`` is already present in dataset.
778
+
779
+ Parameters
780
+ ----------
781
+ key : Hashable | list[Hashable] | Mapping
782
+ Variable name
783
+ value : Any
784
+ Value to set to variable names
785
+
786
+ See Also
787
+ --------
788
+ - :class:`xarray.Dataset.__setitem__`
789
+ """
790
+
791
+ # pull data of MetDataArray value
792
+ if isinstance(value, MetDataArray):
793
+ value = value.data
794
+
795
+ # warn if key is already in Dataset
796
+ override_keys: list[Hashable] = []
797
+
798
+ if isinstance(key, Hashable):
799
+ if key in self:
800
+ override_keys = [key]
801
+
802
+ # xarray.core.utils.is_dict_like
803
+ # https://github.com/pydata/xarray/blob/4cae8d0ec04195291b2315b1f21d846c2bad61ff/xarray/core/utils.py#L244
804
+ elif xr.core.utils.is_dict_like(key) or isinstance(key, list):
805
+ override_keys = [k for k in key if k in self.data]
806
+
807
+ if override_keys:
808
+ warnings.warn(
809
+ f"Overwriting data in keys `{override_keys}`. "
810
+ "Use `.update(...)` to suppress warning."
811
+ )
812
+
813
+ self.data.__setitem__(key, value)
814
+
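A short sketch of the difference between item assignment and update, assuming met is an existing MetDataset and da is an xr.DataArray on the same coordinates:

    met["relative_humidity"] = da     # warns if "relative_humidity" already exists
    met.update(relative_humidity=da)  # overwrites without a warning
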
815
+ def update(self, other: MutableMapping | None = None, **kwargs: Any) -> None:
816
+ """Shortcut to :meth:`data.update`.
817
+
818
+ See :meth:`xarray.Dataset.update` for reference.
819
+
820
+ Parameters
821
+ ----------
822
+ other : MutableMapping
823
+ Variables with which to update this dataset
824
+ **kwargs : Any
825
+ Variables defined by keyword arguments. If a variable exists both in
826
+ ``other`` and as a keyword argument, the keyword argument takes
827
+ precedence.
828
+
829
+ See Also
830
+ --------
831
+ - :meth:`xarray.Dataset.update`
832
+ """
833
+ other = other or {}
834
+ other.update(kwargs)
835
+
836
+ # pull data of MetDataArray value
837
+ for k, v in other.items():
838
+ if isinstance(v, MetDataArray):
839
+ other[k] = v.data
840
+
841
+ self.data.update(other)
842
+
843
+ def __iter__(self) -> Iterator[str]:
844
+ """Allow for the use as "key" in self.met, where "key" is a data variable."""
845
+ # From the typing perspective, `iter(self.data)`` returns Hashables (not
846
+ # necessarily strs). In everything we do, we use str variables.
847
+ # If we decide to extend this to support Hashable, we'll also want to
848
+ # change VectorDataset -- the underlying :attr:`data` should then be
849
+ # changed to `dict[Hashable, np.ndarray]`.
850
+ for key in self.data:
851
+ yield str(key)
852
+
853
+ def __contains__(self, key: Hashable) -> bool:
854
+ """Check if key ``key`` is in :attr:`data`.
855
+
856
+ Parameters
857
+ ----------
858
+ key : Hashable
859
+ Key to check
860
+
861
+ Returns
862
+ -------
863
+ bool
864
+ True if ``key`` is in :attr:`data`, False otherwise
865
+ """
866
+ return key in self.data
867
+
868
+ @property
869
+ @overrides
870
+ def shape(self) -> tuple[int, int, int, int]:
871
+ sizes = self.data.sizes
872
+ return sizes["longitude"], sizes["latitude"], sizes["level"], sizes["time"]
873
+
874
+ @property
875
+ @overrides
876
+ def size(self) -> int:
877
+ return np.prod(self.shape).item()
878
+
879
+ def copy(self) -> MetDataset:
880
+ """Create a copy of the current class.
881
+
882
+ Returns
883
+ -------
884
+ MetDataset
885
+ MetDataset copy
886
+ """
887
+ return MetDataset(
888
+ self.data,
889
+ cachestore=self.cachestore,
890
+ copy=True, # True by default, but being extra explicit
891
+ )
892
+
893
+ def ensure_vars(
894
+ self,
895
+ vars: MetVariable | str | Sequence[MetVariable | str | Sequence[MetVariable]],
896
+ raise_error: bool = True,
897
+ ) -> list[str]:
898
+ """Ensure variables exist in xr.Dataset.
899
+
900
+ Parameters
901
+ ----------
902
+ vars : MetVariable | str | Sequence[MetVariable | str | list[MetVariable]]
903
+ List of MetVariable (or string key), or individual MetVariable (or string key).
904
+ If ``vars`` contains an element with a list[MetVariable], then
905
+ only one variable in the list must be present in dataset.
906
+ raise_error : bool, optional
907
+ Raise KeyError if data does not contain variables.
908
+ Defaults to True.
909
+
910
+ Returns
911
+ -------
912
+ list[str]
913
+ List of met keys verified in MetDataset.
914
+ Returns an empty list if any MetVariable is missing.
915
+
916
+ Raises
917
+ ------
918
+ KeyError
919
+ Raises when dataset does not contain variable in ``vars``
920
+ """
921
+ if isinstance(vars, MetVariable | str):
922
+ vars = (vars,)
923
+
924
+ met_keys: list[str] = []
925
+ for variable in vars:
926
+ met_key: str | None = None
927
+
928
+ # input is a MetVariable or str
929
+ if isinstance(variable, MetVariable):
930
+ if (key := variable.standard_name) in self:
931
+ met_key = key
932
+ elif isinstance(variable, str):
933
+ if variable in self:
934
+ met_key = variable
935
+
936
+ # otherwise, assume input is a sequence
937
+ # Sequence[MetVariable] means that any variable in list will work
938
+ else:
939
+ for v in variable:
940
+ if (key := v.standard_name) in self:
941
+ met_key = key
942
+ break
943
+
944
+ if met_key is None:
945
+ if not raise_error:
946
+ return []
947
+
948
+ if isinstance(variable, MetVariable):
949
+ raise KeyError(f"Dataset does not contain variable `{variable.standard_name}`")
950
+ if isinstance(variable, list):
951
+ raise KeyError(
952
+ "Dataset does not contain one of variables "
953
+ f"`{[v.standard_name for v in variable]}`"
954
+ )
955
+ raise KeyError(f"Dataset does not contain variable `{variable}`")
956
+
957
+ met_keys.append(met_key)
958
+
959
+ return met_keys
960
+
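A minimal sketch of ensure_vars, assuming met is an existing MetDataset; the variable names are illustrative:

    # raises KeyError if either variable is missing
    keys = met.ensure_vars(["air_temperature", "specific_humidity"])

    # returns an empty list instead of raising
    keys = met.ensure_vars(["air_temperature", "specific_humidity"], raise_error=False)
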
961
+ def save(self, **kwargs: Any) -> list[str]:
962
+ """Save intermediate to :attr:`cachestore` as netcdf.
963
+
964
+ Load and restore using :meth:`load`.
965
+
966
+ Parameters
967
+ ----------
968
+ **kwargs : Any
969
+ Keyword arguments passed directly to :meth:`xarray.Dataset.to_netcdf`
970
+
971
+ Returns
972
+ -------
973
+ list[str]
974
+ Returns filenames saved
975
+ """
976
+ return self._save(self.data, **kwargs)
977
+
978
+ @classmethod
979
+ def load(
980
+ cls,
981
+ hash: str,
982
+ cachestore: CacheStore | None = None,
983
+ chunks: dict[str, int] | None = None,
984
+ ) -> MetDataset:
985
+ """Load saved intermediate from :attr:`cachestore`.
986
+
987
+ Parameters
988
+ ----------
989
+ hash : str
990
+ Saved hash to load.
991
+ cachestore : :class:`CacheStore`, optional
992
+ Cache datastore to use for sourcing files.
993
+ Defaults to DiskCacheStore.
994
+ chunks : dict[str, int], optional
995
+ Chunks kwarg passed to :func:`xarray.open_mfdataset()` when opening files.
996
+
997
+ Returns
998
+ -------
999
+ MetDataset
1000
+ New MetDataset with loaded data.
1001
+ """
1002
+ cachestore = cachestore or DiskCacheStore()
1003
+ chunks = chunks or {}
1004
+ data = _load(hash, cachestore, chunks)
1005
+ return cls(data)
1006
+
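A minimal sketch of the save/load round trip, assuming met is an existing MetDataset using the default DiskCacheStore; reloading by met.hash relies on the hash-based filenames written by _save above:

    filenames = met.save()                # one netCDF file per hour in the cachestore
    restored = MetDataset.load(met.hash)  # reload from the default DiskCacheStore
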
1007
+ def wrap_longitude(self) -> MetDataset:
1008
+ """Wrap longitude coordinates.
1009
+
1010
+ Returns
1011
+ -------
1012
+ MetDataset
1013
+ Copy of MetDataset with wrapped longitude values.
1014
+ Returns copy of current MetDataset when longitude values are already wrapped
1015
+ """
1016
+ return MetDataset(
1017
+ _wrap_longitude(self.data),
1018
+ cachestore=self.cachestore,
1019
+ )
1020
+
1021
+ @overrides
1022
+ def broadcast_coords(self, name: str) -> xr.DataArray:
1023
+ da = xr.ones_like(self.data[next(iter(self.data.keys()))]) * self.data[name]
1024
+ da.name = name
1025
+
1026
+ return da
1027
+
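A short sketch of broadcast_coords, assuming met is an existing MetDataset of pressure-level data:

    # air_pressure is a non-dimension coordinate added on construction
    ap = met.broadcast_coords("air_pressure")  # DataArray with the same shape as the grid
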
1028
+ @overrides
1029
+ def downselect(self, bbox: tuple[float, ...]) -> MetDataset:
1030
+ data = downselect(self.data, bbox)
1031
+ return MetDataset(data, cachestore=self.cachestore, copy=False)
1032
+
1033
+ def to_vector(self, transfer_attrs: bool = True) -> vector_module.GeoVectorDataset:
1034
+ """Convert a :class:`MetDataset` to a :class:`GeoVectorDataset` by raveling data.
1035
+
1036
+ If :attr:`data` is lazy, it will be loaded.
1037
+
1038
+ Parameters
1039
+ ----------
1040
+ transfer_attrs : bool, optional
1041
+ Transfer attributes from :attr:`data` to output :class:`GeoVectorDataset`.
1042
+ By default, True, meaning that attributes are transferred.
1043
+
1044
+ Returns
1045
+ -------
1046
+ GeoVectorDataset
1047
+ Converted :class:`GeoVectorDataset`. The variables on the returned instance
1048
+ include all of those on the input instance, plus the four core spatial temporal
1049
+ variables.
1050
+
1051
+ Examples
1052
+ --------
1053
+ >>> from pycontrails.datalib.ecmwf import ERA5
1054
+ >>> times = "2022-03-01", "2022-03-01T01"
1055
+ >>> variables = ["air_temperature", "specific_humidity"]
1056
+ >>> levels = [250, 200]
1057
+ >>> era5 = ERA5(time=times, variables=variables, pressure_levels=levels)
1058
+ >>> met = era5.open_metdataset()
1059
+ >>> met.to_vector(transfer_attrs=False)
1060
+ GeoVectorDataset [6 keys x 4152960 length, 1 attributes]
1061
+ Keys: longitude, latitude, level, time, air_temperature, ..., specific_humidity
1062
+ Attributes:
1063
+ time [2022-03-01 00:00:00, 2022-03-01 01:00:00]
1064
+ longitude [-180.0, 179.75]
1065
+ latitude [-90.0, 90.0]
1066
+ altitude [10362.8, 11783.9]
1067
+ crs EPSG:4326
1068
+
1069
+ """
1070
+ coords_keys = self.data.dims
1071
+ indexes = self.indexes
1072
+ coords_vals = [indexes[key].values for key in coords_keys]
1073
+ coords_meshes = np.meshgrid(*coords_vals, indexing="ij")
1074
+ raveled_coords = (mesh.ravel() for mesh in coords_meshes)
1075
+ data = dict(zip(coords_keys, raveled_coords, strict=True))
1076
+
1077
+ out = vector_module.GeoVectorDataset(data, copy=False)
1078
+ for key, da in self.data.items():
1079
+ # The call to .values here will load the data if it is lazy
1080
+ out[key] = da.values.ravel() # type: ignore[index]
1081
+
1082
+ if transfer_attrs:
1083
+ out.attrs.update(self.attrs) # type: ignore[arg-type]
1084
+
1085
+ return out
1086
+
1087
+ def _get_pycontrails_attr_template(
1088
+ self,
1089
+ name: str,
1090
+ supported: tuple[str, ...],
1091
+ examples: dict[str, str],
1092
+ ) -> str:
1093
+ """Look up an attribute with a custom error message."""
1094
+ try:
1095
+ out = self.attrs[name]
1096
+ except KeyError as e:
1097
+ msg = f"Specify '{name}' attribute on underlying dataset."
1098
+
1099
+ for i, (k, v) in enumerate(examples.items()):
1100
+ if i == 0:
1101
+ msg = f"{msg} For example, set attrs['{name}'] = '{k}' for {v}."
1102
+ else:
1103
+ msg = f"{msg} Set attrs['{name}'] = '{k}' for {v}."
1104
+ raise KeyError(msg) from e
1105
+
1106
+ if out not in supported:
1107
+ warnings.warn(
1108
+ f"Unknown {name} '{out}'. Data may not be processed correctly. "
1109
+ f"Known {name}s are {supported}. Contact the pycontrails "
1110
+ "developers if you believe this is an error."
1111
+ )
1112
+
1113
+ return out
1114
+
1115
+ @property
1116
+ def provider_attr(self) -> str:
1117
+ """Look up the 'provider' attribute with a custom error message.
1118
+
1119
+ Returns
1120
+ -------
1121
+ str
1122
+ Provider of the data. If not one of 'ECMWF' or 'NCEP',
1123
+ a warning is issued.
1124
+ """
1125
+ supported = ("ECMWF", "NCEP")
1126
+ examples = {"ECMWF": "data provided by ECMWF", "NCEP": "GFS data"}
1127
+ return self._get_pycontrails_attr_template("provider", supported, examples)
1128
+
1129
+ @property
1130
+ def dataset_attr(self) -> str:
1131
+ """Look up the 'dataset' attribute with a custom error message.
1132
+
1133
+ Returns
1134
+ -------
1135
+ str
1136
+ Dataset of the data. If not one of 'ERA5', 'HRES', 'IFS',
1137
+ or 'GFS', a warning is issued.
1138
+ """
1139
+ supported = ("ERA5", "HRES", "IFS", "GFS")
1140
+ examples = {
1141
+ "ERA5": "ECMWF ERA5 reanalysis data",
1142
+ "HRES": "ECMWF HRES forecast data",
1143
+ "GFS": "NCEP GFS forecast data",
1144
+ }
1145
+ return self._get_pycontrails_attr_template("dataset", supported, examples)
1146
+
1147
+ @property
1148
+ def product_attr(self) -> str:
1149
+ """Look up the 'product' attribute with a custom error message.
1150
+
1151
+ Returns
1152
+ -------
1153
+ str
1154
+ Product of the data. If not one of 'forecast', 'ensemble', or 'reanalysis',
1155
+ a warning is issued.
1156
+
1157
+ """
1158
+ supported = ("reanalysis", "forecast", "ensemble")
1159
+ examples = {
1160
+ "reanalysis": "ECMWF ERA5 reanalysis data",
1161
+ "ensemble": "ECMWF ERA5 ensemble member data",
1162
+ }
1163
+ return self._get_pycontrails_attr_template("product", supported, examples)
1164
+
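A minimal sketch of setting the three attributes described above at construction time, assuming ds is an xr.Dataset of ERA5 reanalysis data; the keyword arguments are merged into data.attrs:

    from pycontrails import MetDataset

    met = MetDataset(ds, provider="ECMWF", dataset="ERA5", product="reanalysis")
    met.provider_attr  # "ECMWF"
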
1165
+ def standardize_variables(self, variables: Iterable[MetVariable]) -> None:
1166
+ """Standardize variables **in-place**.
1167
+
1168
+ Parameters
1169
+ ----------
1170
+ variables : Iterable[MetVariable]
1171
+ Data source variables
1172
+
1173
+ See Also
1174
+ --------
1175
+ :func:`standardize_variables`
1176
+ """
1177
+ standardize_variables(self, variables)
1178
+
1179
+ @classmethod
1180
+ def from_coords(
1181
+ cls,
1182
+ longitude: npt.ArrayLike | float,
1183
+ latitude: npt.ArrayLike | float,
1184
+ level: npt.ArrayLike | float,
1185
+ time: npt.ArrayLike | np.datetime64,
1186
+ ) -> MetDataset:
1187
+ r"""Create a :class:`MetDataset` containing a coordinate skeleton from coordinate arrays.
1188
+
1189
+ Parameters
1190
+ ----------
1191
+ longitude, latitude : npt.ArrayLike | float
1192
+ Horizontal coordinates, in [:math:`\deg`]
1193
+ level : npt.ArrayLike | float
1194
+ Vertical coordinate, in [:math:`hPa`]
1195
+ time : npt.ArrayLike | np.datetime64
1196
+ Temporal coordinates, in [:math:`UTC`]. Will be sorted.
1197
+
1198
+ Returns
1199
+ -------
1200
+ MetDataset
1201
+ MetDataset with no variables.
1202
+
1203
+ Examples
1204
+ --------
1205
+ >>> # Create skeleton MetDataset
1206
+ >>> longitude = np.arange(0, 10, 0.5)
1207
+ >>> latitude = np.arange(0, 10, 0.5)
1208
+ >>> level = [250, 300]
1209
+ >>> time = np.datetime64("2019-01-01")
1210
+ >>> met = MetDataset.from_coords(longitude, latitude, level, time)
1211
+ >>> met
1212
+ MetDataset with data:
1213
+ <xarray.Dataset> Size: 360B
1214
+ Dimensions: (longitude: 20, latitude: 20, level: 2, time: 1)
1215
+ Coordinates:
1216
+ * longitude (longitude) float64 160B 0.0 0.5 1.0 1.5 ... 8.0 8.5 9.0 9.5
1217
+ * latitude (latitude) float64 160B 0.0 0.5 1.0 1.5 ... 8.0 8.5 9.0 9.5
1218
+ * level (level) float64 16B 250.0 300.0
1219
+ * time (time) datetime64[ns] 8B 2019-01-01
1220
+ air_pressure (level) float32 8B 2.5e+04 3e+04
1221
+ altitude (level) float32 8B 1.036e+04 9.164e+03
1222
+ Data variables:
1223
+ *empty*
1224
+
1225
+ >>> met.shape
1226
+ (20, 20, 2, 1)
1227
+
1228
+ >>> met.size
1229
+ 800
1230
+
1231
+ >>> # Fill it up with some constant data
1232
+ >>> met["temperature"] = xr.DataArray(np.full(met.shape, 234.5), coords=met.coords)
1233
+ >>> met["humidity"] = xr.DataArray(np.full(met.shape, 0.5), coords=met.coords)
1234
+ >>> met
1235
+ MetDataset with data:
1236
+ <xarray.Dataset> Size: 13kB
1237
+ Dimensions: (longitude: 20, latitude: 20, level: 2, time: 1)
1238
+ Coordinates:
1239
+ * longitude (longitude) float64 160B 0.0 0.5 1.0 1.5 ... 8.0 8.5 9.0 9.5
1240
+ * latitude (latitude) float64 160B 0.0 0.5 1.0 1.5 ... 8.0 8.5 9.0 9.5
1241
+ * level (level) float64 16B 250.0 300.0
1242
+ * time (time) datetime64[ns] 8B 2019-01-01
1243
+ air_pressure (level) float32 8B 2.5e+04 3e+04
1244
+ altitude (level) float32 8B 1.036e+04 9.164e+03
1245
+ Data variables:
1246
+ temperature (longitude, latitude, level, time) float64 6kB 234.5 ... 234.5
1247
+ humidity (longitude, latitude, level, time) float64 6kB 0.5 0.5 ... 0.5
1248
+
1249
+ >>> # Convert to a GeoVectorDataset
1250
+ >>> vector = met.to_vector()
1251
+ >>> vector.dataframe.head()
1252
+ longitude latitude level time temperature humidity
1253
+ 0 0.0 0.0 250.0 2019-01-01 234.5 0.5
1254
+ 1 0.0 0.0 300.0 2019-01-01 234.5 0.5
1255
+ 2 0.0 0.5 250.0 2019-01-01 234.5 0.5
1256
+ 3 0.0 0.5 300.0 2019-01-01 234.5 0.5
1257
+ 4 0.0 1.0 250.0 2019-01-01 234.5 0.5
1258
+ """
1259
+ input_data = {
1260
+ "longitude": longitude,
1261
+ "latitude": latitude,
1262
+ "level": level,
1263
+ "time": time,
1264
+ }
1265
+
1266
+ # clean up input into coords
1267
+ coords: dict[str, np.ndarray] = {}
1268
+ for key, val in input_data.items():
1269
+ dtype = "datetime64[ns]" if key == "time" else COORD_DTYPE
1270
+ arr: np.ndarray = np.asarray(val, dtype=dtype) # type: ignore[call-overload]
1271
+
1272
+ if arr.ndim == 0:
1273
+ arr = arr.reshape(1)
1274
+ elif arr.ndim > 1:
1275
+ raise ValueError(f"{key} has too many dimensions")
1276
+
1277
+ arr = np.sort(arr)
1278
+ if arr.size == 0:
1279
+ raise ValueError(f"Coordinate {key} must be nonempty.")
1280
+
1281
+ coords[key] = arr
1282
+
1283
+ return cls(xr.Dataset({}, coords=coords))
1284
+
1285
+ @classmethod
1286
+ def from_zarr(cls, store: Any, **kwargs: Any) -> MetDataset:
1287
+ """Create a :class:`MetDataset` from a path to a Zarr store.
1288
+
1289
+ Parameters
1290
+ ----------
1291
+ store : Any
1292
+ Path to Zarr store. Passed into :func:`xarray.open_zarr`.
1293
+ **kwargs : Any
1294
+ Other keyword only arguments passed into :func:`xarray.open_zarr`.
1295
+
1296
+ Returns
1297
+ -------
1298
+ MetDataset
1299
+ MetDataset with data from Zarr store.
1300
+ """
1301
+ kwargs.setdefault("storage_options", {"read_only": True})
1302
+ ds = xr.open_zarr(store, **kwargs)
1303
+ return cls(ds)
1304
+
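A minimal sketch of from_zarr; the store path is hypothetical and chunks is forwarded to xarray.open_zarr:

    from pycontrails import MetDataset

    met = MetDataset.from_zarr("gs://my-bucket/era5.zarr", chunks={"time": 1})
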
1305
+
1306
+ class MetDataArray(MetBase):
1307
+ """Meteorological DataArray of single variable.
1308
+
1309
+ Wrapper around xr.DataArray to enforce certain
1310
+ variables and dimensions for internal usage.
1311
+
1312
+ Parameters
1313
+ ----------
1314
+ data : ArrayLike
1315
+ xr.DataArray or other array-like data source.
1316
+ When array-like input is provided, input ``**kwargs`` passed directly to
1317
+ xr.DataArray constructor.
1318
+ cachestore : :class:`CacheStore`, optional
1319
+ Cache datastore for staging intermediates with :meth:`save`.
1320
+ Defaults to DiskCacheStore.
1321
+ wrap_longitude : bool, optional
1322
+ Wrap data along the longitude dimension. If True, duplicate and shift longitude
1323
+ values (ie, -180 -> 180) to ensure that the longitude dimension covers the entire
1324
+ interval ``[-180, 180]``. Defaults to False.
1325
+ copy : bool, optional
1326
+ Copy `data` parameter on construction, by default `True`. If `data` is lazy-loaded
1327
+ via `dask`, this parameter has no effect. If `data` is already loaded into memory,
1328
+ a copy of the data (rather than a view) may be created if `True`.
1329
+ validate : bool, optional
1330
+ Confirm that the parameter `data` has correct specification. This is automatically handled
1331
+ in the case that `copy=True`. Validation only introduces a very small overhead.
1332
+ This parameter should only be set to `False` if working with data derived from an
1333
+ existing MetDataset or :class:`MetDataArray`. By default `True`.
1334
+ name : Hashable, optional
1335
+ Name of the data variable. If not specified, the name will be set to "met".
1336
+ **kwargs
1337
+ To be removed in future versions. Passed directly to xr.DataArray constructor.
1338
+
1339
+ Examples
1340
+ --------
1341
+ >>> import numpy as np
1342
+ >>> import xarray as xr
1343
+ >>> rng = np.random.default_rng(seed=456)
1344
+
1345
+ >>> # Cook up random xarray object
1346
+ >>> coords = {
1347
+ ... "longitude": np.arange(-20, 20),
1348
+ ... "latitude": np.arange(-30, 30),
1349
+ ... "level": [220, 240, 260, 280],
1350
+ ... "time": [np.datetime64("2021-08-01T12", "ns"), np.datetime64("2021-08-01T16", "ns")]
1351
+ ... }
1352
+ >>> da = xr.DataArray(rng.random((40, 60, 4, 2)), dims=coords.keys(), coords=coords)
1353
+
1354
+ >>> # Cast to `MetDataArray` in order to interpolate
1355
+ >>> from pycontrails import MetDataArray
1356
+ >>> mda = MetDataArray(da)
1357
+ >>> mda.interpolate(-11.4, 5.7, 234, np.datetime64("2021-08-01T13"))
1358
+ array([0.52358215])
1359
+
1360
+ >>> mda.interpolate(-11.4, 5.7, 234, np.datetime64("2021-08-01T13"), method='nearest')
1361
+ array([0.4188465])
1362
+
1363
+ >>> da.sel(longitude=-11, latitude=6, level=240, time=np.datetime64("2021-08-01T12")).item()
1364
+ 0.41884649899766946
1365
+ """
1366
+
1367
+ data: xr.DataArray
1368
+
1369
+ def __init__(
1370
+ self,
1371
+ data: xr.DataArray,
1372
+ cachestore: CacheStore | None = None,
1373
+ wrap_longitude: bool = False,
1374
+ copy: bool = True,
1375
+ validate: bool = True,
1376
+ name: Hashable | None = None,
1377
+ **kwargs: Any,
1378
+ ) -> None:
1379
+ # init cache
1380
+ self.cachestore = cachestore
1381
+
1382
+ # try to create DataArray out of input data and **kwargs
1383
+ if not isinstance(data, xr.DataArray):
1384
+ warnings.warn(
1385
+ "Input 'data' must be an xarray DataArray. "
1386
+ "Passing arbitrary kwargs will be removed in future versions.",
1387
+ DeprecationWarning,
1388
+ )
1389
+ data = xr.DataArray(data, **kwargs)
1390
+
1391
+ if copy:
1392
+ self.data = data.copy()
1393
+ self._preprocess_dims(wrap_longitude)
1394
+ else:
1395
+ if wrap_longitude:
1396
+ raise ValueError("Set 'copy=True' when using 'wrap_longitude=True'.")
1397
+ self.data = data
1398
+ if validate:
1399
+ self._validate_dims()
1400
+
1401
+ # Priority: name > data.name > "met"
1402
+ name = name or self.data.name or "met"
1403
+ self.data.name = name
1404
+
1405
+ @property
1406
+ def values(self) -> np.ndarray:
1407
+ """Return underlying numpy array.
1408
+
1409
+ This method loads :attr:`data` if it is not already in memory.
1410
+
1411
+ Returns
1412
+ -------
1413
+ np.ndarray
1414
+ Underlying numpy array
1415
+
1416
+ See Also
1417
+ --------
1418
+ - :meth:`xr.Dataset.load`
1419
+ - :meth:`xr.DataArray.load`
1420
+ """
1421
+ if not self.in_memory:
1422
+ self._check_memory("Extracting numpy array from")
1423
+ self.data.load()
1424
+
1425
+ return self.data.values
1426
+
1427
+ @property
1428
+ def name(self) -> Hashable:
1429
+ """Return the DataArray name.
1430
+
1431
+ Returns
1432
+ -------
1433
+ Hashable
1434
+ DataArray name
1435
+ """
1436
+ return self.data.name
1437
+
1438
+ @property
1439
+ def binary(self) -> bool:
1440
+ """Determine if all data is a binary value (0, 1).
1441
+
1442
+ Returns
1443
+ -------
1444
+ bool
1445
+ True if all data values are binary (0 or 1)
1446
+ """
1447
+ return np.array_equal(self.data, self.data.astype(bool))
1448
+
1449
+ @property
1450
+ @overrides
1451
+ def size(self) -> int:
1452
+ return self.data.size
1453
+
1454
+ @property
1455
+ @overrides
1456
+ def shape(self) -> tuple[int, int, int, int]:
1457
+ # https://github.com/python/mypy/issues/1178
1458
+ return typing.cast(tuple[int, int, int, int], self.data.shape)
1459
+
1460
+ def copy(self) -> MetDataArray:
1461
+ """Create a copy of the current class.
1462
+
1463
+ Returns
1464
+ -------
1465
+ MetDataArray
1466
+ MetDataArray copy
1467
+ """
1468
+ return MetDataArray(self.data, cachestore=self.cachestore, copy=True)
1469
+
1470
+ def wrap_longitude(self) -> MetDataArray:
1471
+ """Wrap longitude coordinates.
1472
+
1473
+ Returns
1474
+ -------
1475
+ MetDataArray
1476
+ Copy of MetDataArray with wrapped longitude values.
1477
+ Returns copy of current MetDataArray when longitude values are already wrapped
1478
+ """
1479
+ return MetDataArray(_wrap_longitude(self.data), cachestore=self.cachestore)
1480
+
1481
+ @property
1482
+ def in_memory(self) -> bool:
1483
+ """Check if underlying :attr:`data` is loaded into memory.
1484
+
1485
+ This method uses protected attributes of underlying `xarray` objects, and may be subject
1486
+ to deprecation.
1487
+
1488
+ .. versionchanged:: 0.26.0
1489
+
1490
+ Rename from ``is_loaded`` to ``in_memory``.
1491
+
1492
+ Returns
1493
+ -------
1494
+ bool
1495
+ If underlying data exists as an `np.ndarray` in memory.
1496
+ """
1497
+ return self.data._in_memory
1498
+
1499
+ @overload
1500
+ def interpolate(
1501
+ self,
1502
+ longitude: float | npt.NDArray[np.float64],
1503
+ latitude: float | npt.NDArray[np.float64],
1504
+ level: float | npt.NDArray[np.float64],
1505
+ time: np.datetime64 | npt.NDArray[np.datetime64],
1506
+ *,
1507
+ method: str = ...,
1508
+ bounds_error: bool = ...,
1509
+ fill_value: float | np.float64 | None = ...,
1510
+ localize: bool = ...,
1511
+ lowmem: bool = ...,
1512
+ indices: interpolation.RGIArtifacts | None = ...,
1513
+ return_indices: Literal[False] = ...,
1514
+ ) -> npt.NDArray[np.float64]: ...
1515
+
1516
+ @overload
1517
+ def interpolate(
1518
+ self,
1519
+ longitude: float | npt.NDArray[np.float64],
1520
+ latitude: float | npt.NDArray[np.float64],
1521
+ level: float | npt.NDArray[np.float64],
1522
+ time: np.datetime64 | npt.NDArray[np.datetime64],
1523
+ *,
1524
+ method: str = ...,
1525
+ bounds_error: bool = ...,
1526
+ fill_value: float | np.float64 | None = ...,
1527
+ localize: bool = ...,
1528
+ lowmem: bool = ...,
1529
+ indices: interpolation.RGIArtifacts | None = ...,
1530
+ return_indices: Literal[True],
1531
+ ) -> tuple[npt.NDArray[np.float64], interpolation.RGIArtifacts]: ...
1532
+
1533
+ def interpolate(
1534
+ self,
1535
+ longitude: float | npt.NDArray[np.float64],
1536
+ latitude: float | npt.NDArray[np.float64],
1537
+ level: float | npt.NDArray[np.float64],
1538
+ time: np.datetime64 | npt.NDArray[np.datetime64],
1539
+ *,
1540
+ method: str = "linear",
1541
+ bounds_error: bool = False,
1542
+ fill_value: float | np.float64 | None = np.nan,
1543
+ localize: bool = False,
1544
+ lowmem: bool = False,
1545
+ indices: interpolation.RGIArtifacts | None = None,
1546
+ return_indices: bool = False,
1547
+ ) -> npt.NDArray[np.float64] | tuple[npt.NDArray[np.float64], interpolation.RGIArtifacts]:
1548
+ """Interpolate values over underlying DataArray.
1549
+
1550
+ Zero dimensional coordinates are reshaped to 1D arrays.
1551
+
1552
+ If ``lowmem == False``, the method automatically loads the underlying :attr:`data` into
1553
+ memory. Otherwise, the method iterates through smaller subsets of :attr:`data` and releases
1554
+ subsets from memory once interpolation against each subset is finished.
1555
+
1556
+ If ``method == "nearest"``, the out array will have the same ``dtype`` as
1557
+ the underlying :attr:`data`.
1558
+
1559
+ If ``method == "linear"``, the out array will be promoted to the most
1560
+ precise ``dtype`` of:
1561
+
1562
+ - underlying :attr:`data`
1563
+ - :attr:`data.longitude`
1564
+ - :attr:`data.latitude`
1565
+ - :attr:`data.level`
1566
+ - ``longitude``
1567
+ - ``latitude``
1568
+
1569
+ .. versionadded:: 0.24
1570
+
1571
+ This method can now handle singleton dimensions with ``method == "linear"``.
1572
+ Previously these degenerate dimensions caused nan values to be returned.
1573
+
1574
+ Parameters
1575
+ ----------
1576
+ longitude : float | npt.NDArray[np.float64]
1577
+ Longitude values to interpolate. Assumed to be 0 or 1 dimensional.
1578
+ latitude : float | npt.NDArray[np.float64]
1579
+ Latitude values to interpolate. Assumed to be 0 or 1 dimensional.
1580
+ level : float | npt.NDArray[np.float64]
1581
+ Level values to interpolate. Assumed to be 0 or 1 dimensional.
1582
+ time : np.datetime64 | npt.NDArray[np.datetime64]
1583
+ Time values to interpolate. Assumed to be 0 or 1 dimensional.
1584
+ method: str, optional
1585
+ Additional keyword arguments to pass to
1586
+ :class:`scipy.interpolate.RegularGridInterpolator`.
1587
+ Defaults to "linear".
1588
+ bounds_error: bool, optional
1589
+ Additional keyword arguments to pass to
1590
+ :class:`scipy.interpolate.RegularGridInterpolator`.
1591
+ Defaults to ``False``.
1592
+ fill_value: float | np.float64, optional
1593
+ Additional keyword arguments to pass to
1594
+ :class:`scipy.interpolate.RegularGridInterpolator`.
1595
+ Set to None to extrapolate outside the boundary when ``method`` is ``nearest``.
1596
+ Defaults to ``np.nan``.
1597
+ localize: bool, optional
1598
+ Experimental. If True, downselect gridded data to smallest bounding box containing
1599
+ all points. By default False.
1600
+ lowmem: bool, optional
1601
+ Experimental. If True, iterate through points binned by the time coordinate of the
1602
+ gridded data, and downselect gridded data to the smallest bounding box containing
1603
+ each binned set of points *before loading into memory*. This can significantly reduce
1604
+ memory consumption with large numbers of points at the cost of increased runtime.
1605
+ By default False.
1606
+ indices: tuple | None, optional
1607
+ Experimental. See :func:`interpolation.interp`. None by default.
1608
+ return_indices: bool, optional
1609
+ Experimental. See :func:`interpolation.interp`. False by default.
1610
+ Note that values returned differ when ``lowmem=True`` and ``lowmem=False``,
1611
+ so output should only be re-used in calls with the same ``lowmem`` value.
1612
+
1613
+ Returns
1614
+ -------
1615
+ np.ndarray
1616
+ Interpolated values
1617
+
1618
+ See Also
1619
+ --------
1620
+ :meth:`GeoVectorDataset.intersect_met`
1621
+
1622
+ Examples
1623
+ --------
1624
+ >>> from datetime import datetime
1625
+ >>> import numpy as np
1626
+ >>> import pandas as pd
1627
+ >>> from pycontrails.datalib.ecmwf import ERA5
1628
+
1629
+ >>> times = (datetime(2022, 3, 1, 12), datetime(2022, 3, 1, 15))
1630
+ >>> variables = "air_temperature"
1631
+ >>> levels = [200, 250, 300]
1632
+ >>> era5 = ERA5(times, variables, levels)
1633
+ >>> met = era5.open_metdataset()
1634
+ >>> mda = met["air_temperature"]
1635
+
1636
+ >>> # Interpolation at a grid point agrees with value
1637
+ >>> mda.interpolate(1, 2, 300, np.datetime64('2022-03-01T14:00'))
1638
+ array([241.91972984])
1639
+
1640
+ >>> da = mda.data
1641
+ >>> da.sel(longitude=1, latitude=2, level=300, time=np.datetime64('2022-03-01T14')).item()
1642
+ 241.9197298421629
1643
+
1644
+ >>> # Interpolation off grid
1645
+ >>> mda.interpolate(1.1, 2.1, 290, np.datetime64('2022-03-01 13:10'))
1646
+ array([239.83793798])
1647
+
1648
+ >>> # Interpolate along path
1649
+ >>> longitude = np.linspace(1, 2, 10)
1650
+ >>> latitude = np.linspace(2, 3, 10)
1651
+ >>> level = np.linspace(200, 300, 10)
1652
+ >>> time = pd.date_range("2022-03-01T14", periods=10, freq="5min")
1653
+ >>> mda.interpolate(longitude, latitude, level, time)
1654
+ array([220.44347694, 223.08900738, 225.74338924, 228.41642088,
1655
+ 231.10858599, 233.54857391, 235.71504913, 237.86478872,
1656
+ 239.99274623, 242.10792167])
1657
+
1658
+ >>> # Can easily switch to alternative low-memory implementation
1659
+ >>> mda.interpolate(longitude, latitude, level, time, lowmem=True)
1660
+ array([220.44347694, 223.08900738, 225.74338924, 228.41642088,
1661
+ 231.10858599, 233.54857391, 235.71504913, 237.86478872,
1662
+ 239.99274623, 242.10792167])
1663
+ """
1664
+ if lowmem:
1665
+ return self._interp_lowmem(
1666
+ longitude,
1667
+ latitude,
1668
+ level,
1669
+ time,
1670
+ method=method,
1671
+ bounds_error=bounds_error,
1672
+ fill_value=fill_value,
1673
+ indices=indices,
1674
+ return_indices=return_indices,
1675
+ )
1676
+
1677
+ # Load if necessary
1678
+ if not self.in_memory:
1679
+ self._check_memory("Interpolation over")
1680
+ self.data.load()
1681
+
1682
+ # Convert all inputs to 1d arrays
1683
+ # Not validating against ndim >= 2
1684
+ longitude, latitude, level, time = np.atleast_1d(longitude, latitude, level, time)
1685
+
1686
+ # Pass off to the interp function, which does all the heavy lifting
1687
+ return interpolation.interp(
1688
+ longitude=longitude,
1689
+ latitude=latitude,
1690
+ level=level,
1691
+ time=time,
1692
+ da=self.data,
1693
+ method=method,
1694
+ bounds_error=bounds_error,
1695
+ fill_value=fill_value,
1696
+ localize=localize,
1697
+ indices=indices,
1698
+ return_indices=return_indices,
1699
+ )
1700
+
1701
+ def _interp_lowmem(
1702
+ self,
1703
+ longitude: float | npt.NDArray[np.float64],
1704
+ latitude: float | npt.NDArray[np.float64],
1705
+ level: float | npt.NDArray[np.float64],
1706
+ time: np.datetime64 | npt.NDArray[np.datetime64],
1707
+ *,
1708
+ method: str = "linear",
1709
+ bounds_error: bool = False,
1710
+ fill_value: float | np.float64 | None = np.nan,
1711
+ minimize_memory: bool = False,
1712
+ indices: interpolation.RGIArtifacts | None = None,
1713
+ return_indices: bool = False,
1714
+ ) -> npt.NDArray[np.float64] | tuple[npt.NDArray[np.float64], interpolation.RGIArtifacts]:
1715
+ """Interpolate values against underlying DataArray.
1716
+
1717
+ This method is used by :meth:`interpolate` when ``lowmem=True``.
1718
+ Parameters and return types are identical to :meth:`interpolate`, except
1719
+ that the ``localize`` and ``lowmem`` keyword arguments are omitted.
1720
+ """
1721
+ # Convert all inputs to 1d arrays
1722
+ # Not validating against ndim >= 2
1723
+ longitude, latitude, level, time = np.atleast_1d(longitude, latitude, level, time)
1724
+
1725
+ if bounds_error:
1726
+ _lowmem_boundscheck(time, self.data)
1727
+
1728
+ # Create buffers for holding interpolation output
1729
+ # Use np.full rather than np.empty so points not covered
1730
+ # by masks are filled with correct out-of-bounds values.
1731
+ out = np.full(longitude.shape, fill_value, dtype=self.data.dtype)
1732
+ if return_indices:
1733
+ rgi_artifacts = interpolation.RGIArtifacts(
1734
+ xi_indices=np.full((4, longitude.size), -1, dtype=np.int64),
1735
+ norm_distances=np.full((4, longitude.size), np.nan, dtype=np.float64),
1736
+ out_of_bounds=np.full((longitude.size,), True, dtype=np.bool_),
1737
+ )
1738
+
1739
+ # Iterate over portions of points between adjacent time steps in gridded data
1740
+ for mask in _lowmem_masks(time, self.data["time"].values):
1741
+ if mask is None or not np.any(mask):
1742
+ continue
1743
+
1744
+ lon_sl = longitude[mask]
1745
+ lat_sl = latitude[mask]
1746
+ lev_sl = level[mask]
1747
+ t_sl = time[mask]
1748
+ if indices is not None:
1749
+ indices_sl = interpolation.RGIArtifacts(
1750
+ xi_indices=indices.xi_indices[:, mask],
1751
+ norm_distances=indices.norm_distances[:, mask],
1752
+ out_of_bounds=indices.out_of_bounds[mask],
1753
+ )
1754
+ else:
1755
+ indices_sl = None
1756
+
1757
+ coords = {"longitude": lon_sl, "latitude": lat_sl, "level": lev_sl, "time": t_sl}
1758
+ if any(np.all(np.isnan(coord)) for coord in coords.values()):
1759
+ continue
1760
+ da = interpolation._localize(self.data, coords)
1761
+ if not da._in_memory:
1762
+ logger.debug(
1763
+ "Loading %s MB subset of %s into memory.",
1764
+ round(da.nbytes / 1_000_000, 2),
1765
+ da.name,
1766
+ )
1767
+ da.load()
1768
+
1769
+ tmp = interpolation.interp(
1770
+ longitude=lon_sl,
1771
+ latitude=lat_sl,
1772
+ level=lev_sl,
1773
+ time=t_sl,
1774
+ da=da,
1775
+ method=method,
1776
+ bounds_error=bounds_error,
1777
+ fill_value=fill_value,
1778
+ localize=False, # would be no-op; da is localized already
1779
+ indices=indices_sl,
1780
+ return_indices=return_indices,
1781
+ )
1782
+
1783
+ if return_indices:
1784
+ out[mask], rgi_sl = tmp
1785
+ rgi_artifacts.xi_indices[:, mask] = rgi_sl.xi_indices
1786
+ rgi_artifacts.norm_distances[:, mask] = rgi_sl.norm_distances
1787
+ rgi_artifacts.out_of_bounds[mask] = rgi_sl.out_of_bounds
1788
+ else:
1789
+ out[mask] = tmp
1790
+
1791
+ if return_indices:
1792
+ return out, rgi_artifacts
1793
+ return out
1794
+
1795
+ def _check_memory(self, msg_start: str) -> None:
1796
+ """Check the memory usage of the underlying data.
1797
+
1798
+ If the data is larger than 4 GB, a warning is issued. If the data is
1799
+ larger than 32 GB, a RuntimeError is raised.
1800
+ """
1801
+ n_bytes = self.data.nbytes
1802
+ mb = round(n_bytes / int(1e6), 2)
1803
+ logger.debug("Loading %s into memory consumes %s MB.", self.name, mb)
1804
+
1805
+ n_gb = n_bytes // int(1e9)
1806
+ if n_gb <= 4:
1807
+ return
1808
+
1809
+ # Prevent something stupid
1810
+ msg = (
1811
+ f"{msg_start} MetDataArray {self.name} requires loading "
1812
+ f"at least {n_gb} GB of data into memory. Downselect data if possible. "
1813
+ "If working with a GeoVectorDataset instance, this can be achieved "
1814
+ "with the method 'downselect_met'."
1815
+ )
1816
+
1817
+ if n_gb > 32:
1818
+ raise RuntimeError(msg)
1819
+ warnings.warn(msg)
1820
+
1821
+ def save(self, **kwargs: Any) -> list[str]:
1822
+ """Save intermediate to :attr:`cachestore` as netcdf.
1823
+
1824
+ Load and restore using :meth:`load`.
1825
+
1826
+ Parameters
1827
+ ----------
1828
+ **kwargs : Any
1829
+ Keyword arguments passed directly to :func:`xarray.save_mfdataset`
1830
+
1831
+ Returns
1832
+ -------
1833
+ list[str]
1834
+ Returns filenames of saved files
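+
+ Examples
+ --------
+ A minimal round-trip sketch (illustrative, not a prescribed workflow); it assumes
+ this instance's :attr:`hash` identifies the saved files in :attr:`cachestore`::
+
+ paths = mda.save()
+ restored = MetDataArray.load(mda.hash, cachestore=mda.cachestore)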
1835
+ """
1836
+ dataset = self.data.to_dataset()
1837
+ return self._save(dataset, **kwargs)
1838
+
1839
+ @classmethod
1840
+ def load(
1841
+ cls,
1842
+ hash: str,
1843
+ cachestore: CacheStore | None = None,
1844
+ chunks: dict[str, int] | None = None,
1845
+ ) -> MetDataArray:
1846
+ """Load saved intermediate from :attr:`cachestore`.
1847
+
1848
+ Parameters
1849
+ ----------
1850
+ hash : str
1851
+ Saved hash to load.
1852
+ cachestore : CacheStore, optional
1853
+ Cache datastore to use for sourcing files.
1854
+ Defaults to DiskCacheStore.
1855
+ chunks : dict[str, int], optional
1856
+ Chunks kwarg passed to :func:`xarray.open_mfdataset()` when opening files.
1857
+
1858
+ Returns
1859
+ -------
1860
+ MetDataArray
1861
+ New MetDataArray with loaded data.
1862
+ """
1863
+ cachestore = cachestore or DiskCacheStore()
1864
+ chunks = chunks or {}
1865
+ data = _load(hash, cachestore, chunks)
1866
+ return cls(data[next(iter(data.data_vars))])
1867
+
1868
+ @property
1869
+ def proportion(self) -> float:
1870
+ """Compute proportion of points with value 1.
1871
+
1872
+ Returns
1873
+ -------
1874
+ float
1875
+ Proportion of points with value 1
1876
+
1877
+ Raises
1878
+ ------
1879
+ NotImplementedError
1880
+ If instance does not contain binary data.
1881
+ """
1882
+ if not self.binary:
1883
+ raise NotImplementedError("proportion method is only implemented for binary fields")
1884
+
1885
+ return self.data.sum().values.item() / self.data.count().values.item()
1886
+
1887
+ def find_edges(self) -> MetDataArray:
1888
+ """Find edges of regions.
1889
+
1890
+ Returns
1891
+ -------
1892
+ MetDataArray
1893
+ MetDataArray with a binary field, 1 on the edge of the regions,
1894
+ 0 outside and inside the regions.
1895
+
1896
+ Raises
1897
+ ------
1898
+ NotImplementedError
1899
+ If the instance is not binary.
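+
+ Examples
+ --------
+ A minimal sketch, assuming ``mda`` holds a binary field (for example, a
+ thresholded region mask)::
+
+ edges = mda.find_edges()
+ # edges.data equals 1 on region boundaries and 0 elsewhere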
1900
+ """
1901
+ if not self.binary:
1902
+ raise NotImplementedError("find_edges method is only implemented for binary fields")
1903
+
1904
+ # edge detection algorithm using differentiation to reduce the areas to lines
1905
+ def _edges(da: xr.DataArray) -> xr.DataArray:
1906
+ lat_diff = da.differentiate("latitude")
1907
+ lon_diff = da.differentiate("longitude")
1908
+ diff = da.where((lat_diff != 0) | (lon_diff != 0), 0)
1909
+
1910
+ # TODO: what is desired behavior here?
1911
+ # set boundaries to close contour regions
1912
+ diff[dict(longitude=0)] = da[dict(longitude=0)]
1913
+ diff[dict(longitude=-1)] = da[dict(longitude=-1)]
1914
+ diff[dict(latitude=0)] = da[dict(latitude=0)]
1915
+ diff[dict(latitude=-1)] = da[dict(latitude=-1)]
1916
+
1917
+ return diff
1918
+
1919
+ # load data into memory (required for value assignment in _edges())
1920
+ self.data.load()
1921
+
1922
+ data = self.data.groupby("level", squeeze=False).map(_edges)
1923
+ return MetDataArray(data, cachestore=self.cachestore)
1924
+
1925
+ def to_polygon_feature(
1926
+ self,
1927
+ level: float | int | None = None,
1928
+ time: np.datetime64 | datetime | None = None,
1929
+ fill_value: float = np.nan,
1930
+ iso_value: float | None = None,
1931
+ min_area: float = 0.0,
1932
+ epsilon: float = 0.0,
1933
+ lower_bound: bool = True,
1934
+ precision: int | None = None,
1935
+ interiors: bool = True,
1936
+ convex_hull: bool = False,
1937
+ include_altitude: bool = False,
1938
+ properties: dict[str, Any] | None = None,
1939
+ ) -> dict[str, Any]:
1940
+ """Create GeoJSON Feature artifact from spatial array on a single level and time slice.
1941
+
1942
+ Computed polygons always contain an exterior linear ring as defined by the
1943
+ `GeoJSON Polygon specification <https://www.rfc-editor.org/rfc/rfc7946.html#section-3.1.6>`_.
1944
+ Polygons may also contain interior linear rings (holes). This method does not support
1945
+ nesting beyond the GeoJSON specification. See the :mod:`pycontrails.core.polygon`
1946
+ for additional polygon support.
1947
+
1948
+ .. versionchanged:: 0.25.12
1949
+
1950
+ The previous implementation included several additional parameters which have
1951
+ been removed:
1952
+
1953
+ - The ``approximate`` parameter
1954
+ - A ``path`` parameter to save output as JSON
1955
+ - Passing arbitrary kwargs to :func:`skimage.measure.find_contours`.
1956
+
1957
+ The new implementation includes parameters that were previously lacking:
1958
+
1959
+ - ``fill_value``
1960
+ - ``min_area``
1961
+ - ``include_altitude``
1962
+
1963
+ .. versionchanged:: 0.38.0
1964
+
1965
+ Change default value of ``epsilon`` from 0.15 to 0.
1966
+
1967
+ .. versionchanged:: 0.41.0
1968
+
1969
+ Convert continuous fields to binary fields before computing polygons.
1970
+ The parameters ``min_area`` and ``epsilon`` are now expressed in terms of
1971
+ longitude/latitude units instead of pixels.
1972
+
1973
+ Parameters
1974
+ ----------
1975
+ level : float, optional
1976
+ Level slice to create polygons.
1977
+ If the "level" coordinate is length 1, then the single level slice will be selected
1978
+ automatically.
1979
+ time : datetime, optional
1980
+ Time slice to create polygons.
1981
+ If the "time" coordinate is length 1, then the single time slice will be selected
1982
+ automatically.
1983
+ fill_value : float, optional
1984
+ Value used for filling missing data and for padding the underlying data array.
1985
+ Set to ``np.nan`` by default, which ensures that regions with missing data are
1986
+ never included in polygons.
1987
+ iso_value : float, optional
1988
+ Value in field to create iso-surface.
1989
+ Defaults to the average of the min and max value of the array. (This is the
1990
+ same convention as used by ``skimage``.)
1991
+ min_area : float, optional
1992
+ Minimum area of each polygon. Polygons with area less than ``min_area`` are
1993
+ not included in the output. The unit of this parameter is in longitude/latitude
1994
+ degrees squared. Set to 0 to omit any polygon filtering based on a minimal area
1995
+ conditional. By default, 0.0.
1996
+ epsilon : float, optional
1997
+ Control the extent to which the polygon is simplified. A value of 0 does not alter
1998
+ the geometry of the polygon. The unit of this parameter is in longitude/latitude
1999
+ degrees. By default, 0.0.
2000
+ lower_bound : bool, optional
2001
+ Whether to use ``iso_value`` as a lower or upper bound on values in polygon interiors.
2002
+ By default, True.
2003
+ precision : int, optional
2004
+ Number of decimal places to round coordinates to. If None, no rounding is performed.
2005
+ interiors : bool, optional
2006
+ If True, include interior linear rings (holes) in the output. True by default.
2007
+ convex_hull : bool, optional
2008
+ EXPERIMENTAL. If True, compute the convex hull of each polygon. Only implemented
2009
+ for depth=1. False by default. A warning is issued if the underlying algorithm
2010
+ fails to make valid polygons after computing the convex hull.
2011
+ include_altitude : bool, optional
2012
+ If True, include the array altitude [:math:`m`] as a z-coordinate in the
2013
+ `GeoJSON output <https://www.rfc-editor.org/rfc/rfc7946#section-3.1.1>`_.
2014
+ False by default.
2015
+ properties : dict, optional
2016
+ Additional properties to include in the GeoJSON output. By default, None.
2017
+
2018
+ Returns
2019
+ -------
2020
+ dict[str, Any]
2021
+ Python representation of GeoJSON Feature with MultiPolygon geometry.
2022
+
2023
+ Notes
2024
+ -----
2025
+ :class:`Cocip` and :class:`CocipGrid` set some quantities to 0 and other quantities
2026
+ to ``np.nan`` in regions where no contrails form. When computing polygons from
2027
+ :class:`Cocip` or :class:`CocipGrid` output, take care that the choice of
2028
+ ``fill_value`` correctly includes or excludes contrail-free regions. See the
2029
+ :class:`Cocip` documentation for details about ``np.nan`` in model output.
2030
+
2031
+ See Also
2032
+ --------
2033
+ :meth:`to_polyhedra`
2034
+ :func:`pycontrails.core.polygon.find_multipolygon`
2035
+
2036
+ Examples
2037
+ --------
2038
+ >>> from pprint import pprint
2039
+ >>> from pycontrails.datalib.ecmwf import ERA5
2040
+ >>> era5 = ERA5("2022-03-01", variables="air_temperature", pressure_levels=250)
2041
+ >>> mda = era5.open_metdataset()["air_temperature"]
2042
+ >>> mda.shape
2043
+ (1440, 721, 1, 1)
2044
+
2045
+ >>> pprint(mda.to_polygon_feature(iso_value=239.5, precision=2, epsilon=0.1))
2046
+ {'geometry': {'coordinates': [[[[167.88, -22.5],
2047
+ [167.75, -22.38],
2048
+ [167.62, -22.5],
2049
+ [167.75, -22.62],
2050
+ [167.88, -22.5]]],
2051
+ [[[43.38, -33.5],
2052
+ [43.5, -34.12],
2053
+ [43.62, -33.5],
2054
+ [43.5, -33.38],
2055
+ [43.38, -33.5]]]],
2056
+ 'type': 'MultiPolygon'},
2057
+ 'properties': {},
2058
+ 'type': 'Feature'}
2059
+
2060
+ """
2061
+ if convex_hull and interiors:
2062
+ raise ValueError("Set 'interiors=False' to use the 'convex_hull' parameter.")
2063
+
2064
+ arr, altitude = _extract_2d_arr_and_altitude(self, level, time)
2065
+ if not include_altitude:
2066
+ altitude = None # excludes the z-coordinate from the GeoJSON output below
2067
+ elif altitude is None:
2068
+ raise ValueError(
2069
+ "The parameter 'include_altitude' is True, but altitude is not "
2070
+ "found on MetDataArray instance. Either set altitude, or pass "
2071
+ "include_altitude=False."
2072
+ )
2073
+
2074
+ if not np.isnan(fill_value):
2075
+ np.nan_to_num(arr, copy=False, nan=fill_value)
2076
+
2077
+ # default iso_value
2078
+ if iso_value is None:
2079
+ iso_value = (np.nanmax(arr) + np.nanmin(arr)) / 2
2080
+ warnings.warn(f"The 'iso_value' parameter was not specified. Using value: {iso_value}")
2081
+
2082
+ # We'll get a nice error message if dependencies are not installed
2083
+ from pycontrails.core import polygon
2084
+
2085
+ # Convert to nested lists of coordinates for GeoJSON representation
2086
+ indexes = self.indexes
2087
+ longitude = indexes["longitude"].to_numpy()
2088
+ latitude = indexes["latitude"].to_numpy()
2089
+
2090
+ mp = polygon.find_multipolygon(
2091
+ arr,
2092
+ threshold=iso_value,
2093
+ min_area=min_area,
2094
+ epsilon=epsilon,
2095
+ lower_bound=lower_bound,
2096
+ interiors=interiors,
2097
+ convex_hull=convex_hull,
2098
+ longitude=longitude,
2099
+ latitude=latitude,
2100
+ precision=precision,
2101
+ )
2102
+
2103
+ return polygon.multipolygon_to_geojson(mp, altitude, properties)
2104
+
2105
+ def to_polygon_feature_collection(
2106
+ self,
2107
+ time: np.datetime64 | datetime | None = None,
2108
+ fill_value: float = np.nan,
2109
+ iso_value: float | None = None,
2110
+ min_area: float = 0.0,
2111
+ epsilon: float = 0.0,
2112
+ lower_bound: bool = True,
2113
+ precision: int | None = None,
2114
+ interiors: bool = True,
2115
+ convex_hull: bool = False,
2116
+ include_altitude: bool = False,
2117
+ properties: dict[str, Any] | None = None,
2118
+ ) -> dict[str, Any]:
2119
+ """Create GeoJSON FeatureCollection artifact from spatial array at time slice.
2120
+
2121
+ See the :meth:`to_polygon_feature` method for a description of the parameters.
2122
+
2123
+ Returns
2124
+ -------
2125
+ dict[str, Any]
2126
+ Python representation of GeoJSON FeatureCollection. This dictionary is
2127
+ composed of individual GeoJSON Features, one per :attr:`self.data["level"]`.
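+
+ Examples
+ --------
+ A minimal sketch, reusing the single-level ``mda`` from the
+ :meth:`to_polygon_feature` example (output not shown)::
+
+ fc = mda.to_polygon_feature_collection(iso_value=239.5)
+ assert fc["type"] == "FeatureCollection"
+ levels = [f["properties"]["level"] for f in fc["features"]]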
2128
+ """
2129
+ base_properties = properties or {}
2130
+ features = []
2131
+ for level in self.data["level"]:
2132
+ properties = base_properties.copy()
2133
+ properties.update(level=level.item())
2134
+ properties.update({f"level_{k}": v for k, v in self.data["level"].attrs.items()})
2135
+
2136
+ feature = self.to_polygon_feature(
2137
+ level=level,
2138
+ time=time,
2139
+ fill_value=fill_value,
2140
+ iso_value=iso_value,
2141
+ min_area=min_area,
2142
+ epsilon=epsilon,
2143
+ lower_bound=lower_bound,
2144
+ precision=precision,
2145
+ interiors=interiors,
2146
+ convex_hull=convex_hull,
2147
+ include_altitude=include_altitude,
2148
+ properties=properties,
2149
+ )
2150
+ features.append(feature)
2151
+
2152
+ return {
2153
+ "type": "FeatureCollection",
2154
+ "features": features,
2155
+ }
2156
+
2157
+ @overload
2158
+ def to_polyhedra(
2159
+ self,
2160
+ *,
2161
+ time: datetime | None = ...,
2162
+ iso_value: float = ...,
2163
+ simplify_fraction: float = ...,
2164
+ lower_bound: bool = ...,
2165
+ return_type: Literal["geojson"],
2166
+ path: str | None = ...,
2167
+ altitude_scale: float = ...,
2168
+ output_vertex_normals: bool = ...,
2169
+ closed: bool = ...,
2170
+ ) -> dict: ...
2171
+
2172
+ @overload
2173
+ def to_polyhedra(
2174
+ self,
2175
+ *,
2176
+ time: datetime | None = ...,
2177
+ iso_value: float = ...,
2178
+ simplify_fraction: float = ...,
2179
+ lower_bound: bool = ...,
2180
+ return_type: Literal["mesh"],
2181
+ path: str | None = ...,
2182
+ altitude_scale: float = ...,
2183
+ output_vertex_normals: bool = ...,
2184
+ closed: bool = ...,
2185
+ ) -> o3d.geometry.TriangleMesh: ...
2186
+
2187
+ def to_polyhedra(
2188
+ self,
2189
+ *,
2190
+ time: datetime | None = None,
2191
+ iso_value: float = 0.0,
2192
+ simplify_fraction: float = 1.0,
2193
+ lower_bound: bool = True,
2194
+ return_type: str = "geojson",
2195
+ path: str | None = None,
2196
+ altitude_scale: float = 1.0,
2197
+ output_vertex_normals: bool = False,
2198
+ closed: bool = True,
2199
+ ) -> dict | o3d.geometry.TriangleMesh:
2200
+ """Create a collection of polyhedra from spatial array corresponding to a single time slice.
2201
+
2202
+ Parameters
2203
+ ----------
2204
+ time : datetime, optional
2205
+ Time slice to create mesh.
2206
+ If the "time" coordinate is length 1, then the single time slice will be selected
2207
+ automatically.
2208
+ iso_value : float, optional
2209
+ Value in field to create iso-surface. Defaults to 0.
2210
+ simplify_fraction : float, optional
2211
+ Apply `open3d` `simplify_quadric_decimation` method to simplify the polyhedra geometry.
2212
+ This parameter must be in the half-open interval (0.0, 1.0].
2213
+ Defaults to 1.0, corresponding to no reduction.
2214
+ lower_bound : bool, optional
2215
+ Whether to use ``iso_value`` as a lower or upper bound on values in polyhedra interiors.
2216
+ By default, True.
2217
+ return_type : str, optional
2218
+ Must be one of "geojson" or "mesh". Defaults to "geojson".
2219
+ If "geojson", this method returns a dictionary representation of a geojson MultiPolygon
2220
+ object whose polygons are polyhedra faces.
2221
+ If "mesh", this method returns an `open3d` `TriangleMesh` instance.
2222
+ path : str, optional
2223
+ Output geojson or mesh to file.
2224
+ If `return_type` is "mesh", see `Open3D File I/O for Mesh
2225
+ <http://www.open3d.org/docs/release/tutorial/geometry/file_io.html#Mesh>`_ for
2226
+ file type options.
2227
+ altitude_scale : float, optional
2228
+ Rescale the altitude dimension of the mesh, [:math:`m`]
2229
+ output_vertex_normals : bool, optional
2230
+ If ``path`` is defined, write out vertex normals.
2231
+ Defaults to False.
2232
+ closed : bool, optional
2233
+ If True, pad spatial array along all axes to ensure polyhedra are "closed".
2234
+ This flag often gives rise to cleaner visualizations. Defaults to True.
2235
+
2236
+ Returns
2237
+ -------
2238
+ dict | :class:`o3d.geometry.TriangleMesh`
2239
+ Python representation of geojson object or `Open3D Triangle Mesh
2240
+ <http://www.open3d.org/docs/release/tutorial/geometry/mesh.html>`_ depending on the
2241
+ `return_type` parameter.
2242
+
2243
+ Raises
2244
+ ------
2245
+ ModuleNotFoundError
2246
+ Method requires the `vis` optional dependencies
2247
+ ValueError
2248
+ If input parameters are invalid.
2249
+
2250
+ See Also
2251
+ --------
2252
+ :meth:`to_polygon_feature`
2253
+ `skimage.measure.marching_cubes <https://scikit-image.org/docs/dev/api/skimage.measure.html#skimage.measure.marching_cubes>`_
2254
+
2255
+ Notes
2256
+ -----
2257
+ Uses the `scikit-image Marching Cubes <https://scikit-image.org/docs/dev/auto_examples/edges/plot_marching_cubes.html>`_
2258
+ algorithm to reconstruct a surface from the point-cloud like arrays.
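+
+ Examples
+ --------
+ A minimal sketch (requires the ``vis`` and ``open3d`` optional dependencies and a
+ :class:`MetDataArray` ``mda`` with at least two ``level`` coordinates and a single
+ time slice; the ``iso_value`` below is illustrative)::
+
+ geojson = mda.to_polyhedra(iso_value=0.5)
+ mesh = mda.to_polyhedra(iso_value=0.5, return_type="mesh", simplify_fraction=0.5)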
2259
+ """
2260
+ try:
2261
+ from skimage import measure
2262
+ except ModuleNotFoundError as e:
2263
+ dependencies.raise_module_not_found_error(
2264
+ name="MetDataArray.to_polyhedra method",
2265
+ package_name="scikit-image",
2266
+ pycontrails_optional_package="vis",
2267
+ module_not_found_error=e,
2268
+ )
2269
+
2270
+ try:
2271
+ import open3d as o3d
2272
+ except ModuleNotFoundError as e:
2273
+ dependencies.raise_module_not_found_error(
2274
+ name="MetDataArray.to_polyhedra method",
2275
+ package_name="open3d",
2276
+ pycontrails_optional_package="open3d",
2277
+ module_not_found_error=e,
2278
+ )
2279
+
2280
+ if len(self.data["level"]) == 1:
2281
+ raise ValueError(
2282
+ "Found single `level` coordinate in DataArray. This method requires at least two."
2283
+ )
2284
+
2285
+ # select time
2286
+ if time is None and len(self.data["time"]) == 1:
2287
+ time = self.data["time"].values[0]
2288
+
2289
+ if time is None:
2290
+ raise ValueError(
2291
+ "time input must be defined when the length of the time coordinates are > 1"
2292
+ )
2293
+
2294
+ if simplify_fraction > 1 or simplify_fraction <= 0:
2295
+ raise ValueError("Parameter `simplify_fraction` must be in the interval (0, 1].")
2296
+
2297
+ return_types = ["geojson", "mesh"]
2298
+ if return_type not in return_types:
2299
+ raise ValueError(f"Parameter `return_type` must be one of {', '.join(return_types)}")
2300
+
2301
+ # 3d array of longitude, latitude, altitude values
2302
+ volume = self.data.sel(time=time).values
2303
+
2304
+ # invert if iso_value is an upper bound on interior values
2305
+ if not lower_bound:
2306
+ volume = -volume
2307
+ iso_value = -iso_value
2308
+
2309
+ # convert from array index back to coordinates
2310
+ longitude = self.indexes["longitude"].values
2311
+ latitude = self.indexes["latitude"].values
2312
+ altitude = units.pl_to_m(self.indexes["level"].values)
2313
+
2314
+ # Pad volume on all axes to close the volumes
2315
+ if closed:
2316
+ # pad values to domain
2317
+ longitude0 = longitude[0] - (longitude[1] - longitude[0])
2318
+ longitude1 = longitude[-1] + longitude[-1] - longitude[-2]
2319
+ longitude = np.pad(longitude, pad_width=1, constant_values=(longitude0, longitude1))
2320
+
2321
+ latitude0 = latitude[0] - (latitude[1] - latitude[0])
2322
+ latitude1 = latitude[-1] + latitude[-1] - latitude[-2]
2323
+ latitude = np.pad(latitude, pad_width=1, constant_values=(latitude0, latitude1))
2324
+
2325
+ altitude0 = altitude[0] - (altitude[1] - altitude[0])
2326
+ altitude1 = altitude[-1] + altitude[-1] - altitude[-2]
2327
+ altitude = np.pad(altitude, pad_width=1, constant_values=(altitude0, altitude1))
2328
+
2329
+ # Pad along axes to ensure polygons are closed
2330
+ volume = np.pad(volume, pad_width=1, constant_values=iso_value)
2331
+
2332
+ # Use marching cubes to obtain the surface mesh
2333
+ # Coordinates of verts are indexes of volume array
2334
+ verts, faces, normals, _ = measure.marching_cubes(
2335
+ volume, iso_value, allow_degenerate=False, gradient_direction="ascent"
2336
+ )
2337
+
2338
+ # Convert from indexes to longitude, latitude, altitude values
2339
+ verts[:, 0] = longitude[verts[:, 0].astype(int)]
2340
+ verts[:, 1] = latitude[verts[:, 1].astype(int)]
2341
+ verts[:, 2] = altitude[verts[:, 2].astype(int)]
2342
+
2343
+ # rescale altitude
2344
+ verts[:, 2] = verts[:, 2] * altitude_scale
2345
+
2346
+ # create mesh in open3d
2347
+ mesh = o3d.geometry.TriangleMesh()
2348
+ mesh.vertices = o3d.utility.Vector3dVector(verts)
2349
+ mesh.triangles = o3d.utility.Vector3iVector(faces)
2350
+ mesh.vertex_normals = o3d.utility.Vector3dVector(normals)
2351
+
2352
+ # simplify mesh according to simplify_fraction
2353
+ if simplify_fraction < 1:
2354
+ target_n_triangles = int(faces.shape[0] * simplify_fraction)
2355
+ mesh = mesh.simplify_quadric_decimation(target_number_of_triangles=target_n_triangles)
2356
+ mesh.compute_vertex_normals()
2357
+
2358
+ if path is not None:
2359
+ path = str(pathlib.Path(path).absolute())
2360
+
2361
+ if return_type == "geojson":
2362
+ verts = np.round(
2363
+ np.asarray(mesh.vertices), decimals=4
2364
+ ) # rounding to reduce the size of resultant json arrays
2365
+ faces = np.asarray(mesh.triangles)
2366
+
2367
+ # TODO: technically this is not valid GeoJSON because each polygon (triangle)
2368
+ # does not have the last element equal to the first (not a linear ring)
2369
+ # but it still works for now in Deck.GL
2370
+ coords = [[verts[face].tolist()] for face in faces]
2371
+
2372
+ geojson = {
2373
+ "type": "Feature",
2374
+ "properties": {},
2375
+ "geometry": {"type": "MultiPolygon", "coordinates": coords},
2376
+ }
2377
+
2378
+ if path is not None:
2379
+ if not path.endswith(".json"):
2380
+ path += ".json"
2381
+ with open(path, "w") as file:
2382
+ json.dump(geojson, file)
2383
+ return geojson
2384
+
2385
+ if path is not None:
2386
+ o3d.io.write_triangle_mesh(
2387
+ path,
2388
+ mesh,
2389
+ write_ascii=False,
2390
+ compressed=True,
2391
+ write_vertex_normals=output_vertex_normals,
2392
+ write_vertex_colors=False,
2393
+ write_triangle_uvs=True,
2394
+ print_progress=False,
2395
+ )
2396
+ return mesh
2397
+
2398
+ @overrides
2399
+ def broadcast_coords(self, name: str) -> xr.DataArray:
2400
+ da = xr.ones_like(self.data) * self.data[name]
2401
+ da.name = name
2402
+
2403
+ return da
2404
+
2405
+ @overrides
2406
+ def downselect(self, bbox: tuple[float, ...]) -> MetDataArray:
2407
+ data = downselect(self.data, bbox)
2408
+ return MetDataArray(data, cachestore=self.cachestore)
2409
+
2410
+
2411
+ def _is_wrapped(longitude: np.ndarray) -> bool:
2412
+ """Check if ``longitude`` covers ``[-180, 180]``."""
2413
+ return longitude[0] <= -180.0 and longitude[-1] >= 180.0
2414
+
2415
+
2416
+ def _is_zarr(ds: xr.Dataset | xr.DataArray) -> bool:
2417
+ """Check if ``ds`` appears to be Zarr-based.
2418
+
2419
+ Neither ``xarray`` nor ``dask`` readily expose such information, so
2420
+ implementation is very narrow and brittle.
2421
+ """
2422
+ if isinstance(ds, xr.Dataset):
2423
+ # Attempt 1: xarray binds the zarr close function to the instance
2424
+ # This gives us an indication that the data is zarr-based
2425
+ try:
2426
+ if ds._close.__func__.__qualname__ == "ZarrStore.close": # type: ignore
2427
+ return True
2428
+ except AttributeError:
2429
+ pass
2430
+
2431
+ # Grab the first data variable, get underlying DataArray
2432
+ ds = ds[next(iter(ds.data_vars))]
2433
+
2434
+ # Attempt 2: Examine the dask graph
2435
+ darr = ds.variable._data # dask array in some cases
2436
+ try:
2437
+ # Get the first dask instruction
2438
+ dask0 = darr.dask[next(iter(darr.dask))] # type: ignore[union-attr]
2439
+ except AttributeError:
2440
+ return False
2441
+ return dask0.array.array.array.__class__.__name__ == "ZarrArrayWrapper"
2442
+
2443
+
2444
+ def shift_longitude(data: XArrayType, bound: float = -180.0) -> XArrayType:
2445
+ """Shift longitude values from any input domain to [bound, 360 + bound) domain.
2446
+
2447
+ Sorts data by ascending longitude values.
2448
+
2449
+
2450
+ Parameters
2451
+ ----------
2452
+ data : XArrayType
2453
+ :class:`xr.Dataset` or :class:`xr.DataArray` with longitude dimension
2454
+ bound : float, optional
2455
+ Lower bound of the domain.
2456
+ Output domain will be [bound, 360 + bound).
2457
+ Defaults to -180, which results in longitude domain [-180, 180).
2458
+
2459
+
2460
+ Returns
2461
+ -------
2462
+ XArrayType
2463
+ :class:`xr.Dataset` or :class:`xr.DataArray` with longitude values on [bound, 360 + bound).
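+
+ Examples
+ --------
+ A minimal sketch with synthetic coordinates (illustrative only)::
+
+ import numpy as np
+ import xarray as xr
+
+ da = xr.DataArray(
+ np.arange(4.0),
+ dims="longitude",
+ coords={"longitude": [0.0, 90.0, 180.0, 270.0]},
+ )
+ shifted = shift_longitude(da)
+ # shifted["longitude"] is now [-180.0, -90.0, 0.0, 90.0] (sorted ascending)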
2464
+ """
2465
+ return data.assign_coords(
2466
+ longitude=((data["longitude"].values - bound) % 360.0) + bound
2467
+ ).sortby("longitude", ascending=True)
2468
+
2469
+
2470
+ def _wrap_longitude(data: XArrayType) -> XArrayType:
2471
+ """Wrap longitude grid coordinates.
2472
+
2473
+ This function assumes the longitude dimension on ``data``:
2474
+
2475
+ - is sorted
2476
+ - is contained in [-180, 180]
2477
+
2478
+ These assumptions are checked by :class:`MetDataset` and :class`MetDataArray`
2479
+ constructors.
2480
+
2481
+ .. versionchanged:: 0.26.0
2482
+
2483
+ This function now ensures every value in the interval ``[-180, 180]``
2484
+ is covered by the longitude dimension of the returned object. See
2485
+ :meth:`MetDataset.is_wrapped` for more details.
2486
+
2487
+
2488
+ Parameters
2489
+ ----------
2490
+ data : XArrayType
2491
+ :class:`xr.Dataset` or :class:`xr.DataArray` with longitude dimension
2492
+
2493
+ Returns
2494
+ -------
2495
+ XArrayType
2496
+ Copy of :class:`xr.Dataset` or :class:`xr.DataArray` with wrapped longitude values.
2497
+
2498
+ Raises
2499
+ ------
2500
+ ValueError
2501
+ If longitude values are already wrapped.
2502
+ """
2503
+ lon = data._indexes["longitude"].index.to_numpy() # type: ignore[attr-defined]
2504
+ if _is_wrapped(lon):
2505
+ raise ValueError("Longitude values are already wrapped")
2506
+
2507
+ lon0 = lon[0]
2508
+ lon1 = lon[-1]
2509
+
2510
+ # Try to prevent something stupid
2511
+ if lon1 - lon0 < 330.0:
2512
+ warnings.warn("Wrapping longitude will create a large spatial gap of more than 30 degrees.")
2513
+
2514
+ objs = [data]
2515
+ if lon0 > -180.0: # if the lowest longitude is not low enough, duplicate highest
2516
+ dup1 = data.sel(longitude=[lon1]).assign_coords(longitude=[lon1 - 360.0])
2517
+ objs.insert(0, dup1)
2518
+ if lon1 < 180.0: # if the highest longitude is not high enough, duplicate lowest
2519
+ dup0 = data.sel(longitude=[lon0]).assign_coords(longitude=[lon0 + 360.0])
2520
+ objs.append(dup0)
2521
+
2522
+ # Because we explicitly raise a ValueError if longitude already wrapped,
2523
+ # we know that len(objs) > 1, so the concatenation here is nontrivial
2524
+ wrapped = xr.concat(objs, dim="longitude")
2525
+ wrapped["longitude"] = wrapped["longitude"].astype(lon.dtype, copy=False)
2526
+
2527
+ # If there is only one longitude chunk in parameter data, increment
2528
+ # data.chunks can be None, using getattr for extra safety
2529
+ # NOTE: This probably doesn't play well with Zarr data ...
2530
+ # we don't want to be rechunking them.
2531
+ # Ideally we'd raise if data was Zarr-based
2532
+ chunks = getattr(data, "chunks", None) or {}
2533
+ chunks = dict(chunks) # chunks can be frozen
2534
+ lon_chunks = chunks.get("longitude", ())
2535
+ if len(lon_chunks) == 1:
2536
+ chunks["longitude"] = (lon_chunks[0] + len(objs) - 1,)
2537
+ wrapped = wrapped.chunk(chunks)
2538
+
2539
+ return wrapped
2540
+
2541
+
2542
+ def _extract_2d_arr_and_altitude(
2543
+ mda: MetDataArray,
2544
+ level: float | int | None,
2545
+ time: np.datetime64 | datetime | None,
2546
+ ) -> tuple[np.ndarray, float | None]:
2547
+ """Extract underlying 2D array indexed by longitude and latitude.
2548
+
2549
+ Parameters
2550
+ ----------
2551
+ mda : MetDataArray
2552
+ MetDataArray to extract from
2553
+ level : float | int | None
2554
+ Pressure level to slice at
2555
+ time : np.datetime64 | datetime
2556
+ Time to slice at
2557
+
2558
+ Returns
2559
+ -------
2560
+ arr : np.ndarray
2561
+ Copy of 2D array at given level and time
2562
+ altitude : float | None
2563
+ Altitude of slice [:math:`m`]. None if "altitude" not found on ``mda``
2564
+ (ie, for surface level :class:`MetDataArray`).
2565
+ """
2566
+ # Determine level if not specified
2567
+ if level is None:
2568
+ level_coord = mda.indexes["level"].values
2569
+ if len(level_coord) == 1:
2570
+ level = level_coord[0]
2571
+ else:
2572
+ raise ValueError(
2573
+ "Parameter 'level' must be defined when the length of the 'level' "
2574
+ "coordinates is not 1."
2575
+ )
2576
+
2577
+ # Determine time if not specified
2578
+ if time is None:
2579
+ time_coord = mda.indexes["time"].values
2580
+ if len(time_coord) == 1:
2581
+ time = time_coord[0]
2582
+ else:
2583
+ raise ValueError(
2584
+ "Parameter 'time' must be defined when the length of the 'time' "
2585
+ "coordinates is not 1."
2586
+ )
2587
+
2588
+ da = mda.data.sel(level=level, time=time)
2589
+ arr = da.values.copy()
2590
+ if arr.ndim != 2:
2591
+ raise RuntimeError("Malformed data array")
2592
+
2593
+ try:
2594
+ altitude = da["altitude"].values.item() # item not implemented on dask arrays
2595
+ except KeyError:
2596
+ altitude = None
2597
+ else:
2598
+ altitude = round(altitude)
2599
+
2600
+ return arr, altitude
2601
+
2602
+
2603
+ def downselect(data: XArrayType, bbox: tuple[float, ...]) -> XArrayType:
2604
+ """Downselect :class:`xr.Dataset` or :class:`xr.DataArray` with spatial bounding box.
2605
+
2606
+ Parameters
2607
+ ----------
2608
+ data : XArrayType
2609
+ xr.Dataset or xr.DataArray to downselect
2610
+ bbox : tuple[float, ...]
2611
+ Tuple of coordinates defining a spatial bounding box in WGS84 coordinates.
2612
+
2613
+ - For 2D queries, ``bbox`` takes the form ``(west, south, east, north)``
2614
+ - For 3D queries, ``bbox`` takes the form
2615
+ ``(west, south, min-level, east, north, max-level)``
2616
+
2617
+ with level defined in [:math:`hPa`].
2618
+
2619
+ Returns
2620
+ -------
2621
+ XArrayType
2622
+ Downselected xr.Dataset or xr.DataArray
2623
+
2624
+ Raises
2625
+ ------
2626
+ ValueError
2627
+ If parameter ``bbox`` has wrong length.
2628
+ """
2629
+ if len(bbox) == 4:
2630
+ west, south, east, north = bbox
2631
+ level_min = -np.inf
2632
+ level_max = np.inf
2633
+
2634
+ elif len(bbox) == 6:
2635
+ west, south, level_min, east, north, level_max = bbox
2636
+
2637
+ else:
2638
+ raise ValueError(
2639
+ f"bbox {bbox} is not length 4 [west, south, east, north] "
2640
+ "or length 6 [west, south, min-level, east, north, max-level]"
2641
+ )
2642
+
2643
+ cond = (
2644
+ (data["latitude"] >= south)
2645
+ & (data["latitude"] <= north)
2646
+ & (data["level"] >= level_min)
2647
+ & (data["level"] <= level_max)
2648
+ )
2649
+
2650
+ # wrapping longitude
2651
+ if west <= east:
2652
+ cond = cond & (data["longitude"] >= west) & (data["longitude"] <= east)
2653
+ else:
2654
+ cond = cond & ((data["longitude"] >= west) | (data["longitude"] <= east))
2655
+
2656
+ return data.where(cond, drop=True)
2657
+
2658
+
2659
+ def standardize_variables(ds: DatasetType, variables: Iterable[MetVariable]) -> DatasetType:
2660
+ """Rename all variables in dataset from short name to standard name.
2661
+
2662
+ This function does not change any variables in ``ds`` that are not found in ``variables``.
2663
+
2664
+ When there are multiple variables with the same short name, the last one is used.
2665
+
2666
+ Parameters
2667
+ ----------
2668
+ ds : DatasetType
2669
+ An :class:`xr.Dataset` or :class:`MetDataset`. When a :class:`MetDataset` is
2670
+ passed, the underlying :class:`xr.Dataset` is modified in place.
2671
+ variables : Iterable[MetVariable]
2672
+ Data source variables
2673
+
2674
+ Returns
2675
+ -------
2676
+ DatasetType
2677
+ Dataset with variables renamed to standard names
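+
+ Examples
+ --------
+ A minimal sketch, assuming ``ds`` contains the ECMWF short name ``"t"`` and that
+ ``AirTemperature`` is available in :mod:`pycontrails.core.met_var`::
+
+ from pycontrails.core.met_var import AirTemperature
+
+ ds = standardize_variables(ds, [AirTemperature])
+ # the "t" variable is renamed to "air_temperature"; other variables are unchanged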
2678
+ """
2679
+ if isinstance(ds, xr.Dataset):
2680
+ return _standardize_variables(ds, variables)
2681
+
2682
+ ds.data = _standardize_variables(ds.data, variables)
2683
+ return ds
2684
+
2685
+
2686
+ def _standardize_variables(ds: xr.Dataset, variables: Iterable[MetVariable]) -> xr.Dataset:
2687
+ variables_dict: dict[Hashable, str] = {v.short_name: v.standard_name for v in variables}
2688
+ name_dict = {var: variables_dict[var] for var in ds.data_vars if var in variables_dict}
2689
+ return ds.rename(name_dict)
2690
+
2691
+
2692
+ def originates_from_ecmwf(met: MetDataset | MetDataArray) -> bool:
2693
+ """Check if data appears to have originated from an ECMWF source.
2694
+
2695
+ .. versionadded:: 0.27.0
2696
+
2697
+ Experimental. Implementation is brittle.
2698
+
2699
+ Parameters
2700
+ ----------
2701
+ met : MetDataset | MetDataArray
2702
+ Dataset or array to inspect.
2703
+
2704
+ Returns
2705
+ -------
2706
+ bool
2707
+ True if data appears to be derived from an ECMWF source.
2708
+
2709
+ See Also
2710
+ --------
2711
+ - :class:`ERA5`
2712
+ - :class:`HRES`
2713
+
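+ Examples
+ --------
+ A minimal sketch (not executed here)::
+
+ from pycontrails.datalib.ecmwf import ERA5
+
+ era5 = ERA5("2022-03-01", variables="air_temperature", pressure_levels=250)
+ originates_from_ecmwf(era5.open_metdataset())  # expected to be True for ERA5-derived data
+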
2714
+ """
2715
+ if isinstance(met, MetDataset):
2716
+ try:
2717
+ return met.provider_attr == "ECMWF"
2718
+ except KeyError:
2719
+ pass
2720
+ return "ecmwf" in met.attrs.get("history", "")
2721
+
2722
+
2723
+ def _load(hash: str, cachestore: CacheStore, chunks: dict[str, int]) -> xr.Dataset:
2724
+ """Load xarray data from hash.
2725
+
2726
+ Parameters
2727
+ ----------
2728
+ hash : str
2729
+ Description
2730
+ cachestore : CacheStore
2731
+ Description
2732
+ chunks : dict[str, int]
2733
+ Description
2734
+
2735
+ Returns
2736
+ -------
2737
+ xr.Dataset
2738
+ Description
2739
+ """
2740
+ disk_path = cachestore.get(f"{hash}*.nc")
2741
+ return xr.open_mfdataset(disk_path, chunks=chunks)
2742
+
2743
+
2744
+ def _add_vertical_coords(data: XArrayType) -> XArrayType:
2745
+ """Add "air_pressure" and "altitude" coordinates to data.
2746
+
2747
+ .. versionchanged:: 0.52.1
2748
+ Ensure that the ``dtype`` of the additional vertical coordinates agree
2749
+ with the ``dtype`` of the underlying gridded data.
2750
+ """
2751
+
2752
+ data["level"].attrs.update(units="hPa", long_name="Pressure", positive="down")
2753
+
2754
+ # XXX: use the dtype of the data to determine the precision of these coordinates
2755
+ # There are two competing conventions here:
2756
+ # - coordinate data should be float64
2757
+ # - gridded data is typically float32
2758
+ # - air_pressure and altitude often play both roles
2759
+ # It is more important for air_pressure and altitude to be grid-aligned than to be
2760
+ # coordinate-aligned, so we use the dtype of the data to determine the precision of
2761
+ # these coordinates
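+ #
+ # Illustrative example (not from the source): float32 temperature data yields
+ # float32 "air_pressure" and "altitude" coordinates, even though xarray
+ # coordinate data is conventionally float64.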
2762
+ dtype = (
2763
+ np.result_type(*data.data_vars.values(), np.float32)
2764
+ if isinstance(data, xr.Dataset)
2765
+ else data.dtype
2766
+ )
2767
+
2768
+ level = data["level"].values
2769
+
2770
+ if "air_pressure" not in data.coords:
2771
+ data = data.assign_coords(air_pressure=("level", level * 100.0))
2772
+ data.coords["air_pressure"].attrs.update(
2773
+ standard_name=AirPressure.standard_name,
2774
+ long_name=AirPressure.long_name,
2775
+ units=AirPressure.units,
2776
+ )
2777
+ if data.coords["air_pressure"].dtype != dtype:
2778
+ data.coords["air_pressure"] = data.coords["air_pressure"].astype(dtype, copy=False)
2779
+
2780
+ if "altitude" not in data.coords:
2781
+ data = data.assign_coords(altitude=("level", units.pl_to_m(level)))
2782
+ data.coords["altitude"].attrs.update(
2783
+ standard_name=Altitude.standard_name,
2784
+ long_name=Altitude.long_name,
2785
+ units=Altitude.units,
2786
+ )
2787
+ if data.coords["altitude"].dtype != dtype:
2788
+ data.coords["altitude"] = data.coords["altitude"].astype(dtype, copy=False)
2789
+
2790
+ return data
2791
+
2792
+
2793
+ def _lowmem_boundscheck(time: npt.NDArray[np.datetime64], da: xr.DataArray) -> None:
2794
+ """Extra bounds check required with low-memory interpolation strategy.
2795
+
2796
+ Because the main loop in `_interp_lowmem` processes points between time steps
2797
+ in gridded data, it will never encounter points that are out-of-bounds in time
2798
+ and may fail to produce requested out-of-bounds errors.
2799
+ """
2800
+ da_time = da["time"].to_numpy()
2801
+ if not np.all((time >= da_time.min()) & (time <= da_time.max())):
2802
+ axis = da.get_axis_num("time")
2803
+ msg = f"One of the requested xi is out of bounds in dimension {axis}"
2804
+ raise ValueError(msg)
2805
+
2806
+
2807
+ def _lowmem_masks(
2808
+ time: npt.NDArray[np.datetime64], t_met: npt.NDArray[np.datetime64]
2809
+ ) -> Generator[npt.NDArray[np.bool_], None, None]:
2810
+ """Generate sequence of masks for low-memory interpolation."""
2811
+ t_met_max = t_met.max()
2812
+ t_met_min = t_met.min()
2813
+ inbounds = (time >= t_met_min) & (time <= t_met_max)
2814
+ if not np.any(inbounds):
2815
+ return
2816
+
2817
+ earliest = np.nanmin(time)
2818
+ istart = 0 if earliest < t_met_min else np.flatnonzero(t_met <= earliest).max()
2819
+ latest = np.nanmax(time)
2820
+ iend = t_met.size - 1 if latest > t_met_max else np.flatnonzero(t_met >= latest).min()
2821
+ if istart == iend:
2822
+ yield inbounds
2823
+ return
2824
+
2825
+ # Sequence of masks covers elements in time in the interval [t_met[istart], t_met[iend]].
2826
+ # The first iteration masks elements in the interval [t_met[istart], t_met[istart+1]]
2827
+ # (inclusive of both endpoints).
2828
+ # Subsequent iterations mask elements in the interval (t_met[i], t_met[i+1]]
2829
+ # (inclusive of right endpoint only).
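+ #
+ # Illustrative example (not from the source): with t_met = [t0, t1, t2] and
+ # points spanning (t0, t2], the first mask selects points in [t0, t1] and the
+ # second selects points in (t1, t2]; every in-bounds point appears in exactly
+ # one mask.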
2830
+ for i in range(istart, iend):
2831
+ mask = ((time >= t_met[i]) if i == istart else (time > t_met[i])) & (time <= t_met[i + 1])
2832
+ if np.any(mask):
2833
+ yield mask