pycontrails-0.59.0-cp314-cp314-macosx_10_15_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

This version of pycontrails has been flagged as potentially problematic.

Files changed (123)
  1. pycontrails/__init__.py +70 -0
  2. pycontrails/_version.py +34 -0
  3. pycontrails/core/__init__.py +30 -0
  4. pycontrails/core/aircraft_performance.py +679 -0
  5. pycontrails/core/airports.py +228 -0
  6. pycontrails/core/cache.py +889 -0
  7. pycontrails/core/coordinates.py +174 -0
  8. pycontrails/core/fleet.py +483 -0
  9. pycontrails/core/flight.py +2185 -0
  10. pycontrails/core/flightplan.py +228 -0
  11. pycontrails/core/fuel.py +140 -0
  12. pycontrails/core/interpolation.py +702 -0
  13. pycontrails/core/met.py +2936 -0
  14. pycontrails/core/met_var.py +387 -0
  15. pycontrails/core/models.py +1321 -0
  16. pycontrails/core/polygon.py +549 -0
  17. pycontrails/core/rgi_cython.cpython-314-darwin.so +0 -0
  18. pycontrails/core/vector.py +2249 -0
  19. pycontrails/datalib/__init__.py +12 -0
  20. pycontrails/datalib/_met_utils/metsource.py +746 -0
  21. pycontrails/datalib/ecmwf/__init__.py +73 -0
  22. pycontrails/datalib/ecmwf/arco_era5.py +345 -0
  23. pycontrails/datalib/ecmwf/common.py +114 -0
  24. pycontrails/datalib/ecmwf/era5.py +554 -0
  25. pycontrails/datalib/ecmwf/era5_model_level.py +490 -0
  26. pycontrails/datalib/ecmwf/hres.py +804 -0
  27. pycontrails/datalib/ecmwf/hres_model_level.py +466 -0
  28. pycontrails/datalib/ecmwf/ifs.py +287 -0
  29. pycontrails/datalib/ecmwf/model_levels.py +435 -0
  30. pycontrails/datalib/ecmwf/static/model_level_dataframe_v20240418.csv +139 -0
  31. pycontrails/datalib/ecmwf/variables.py +268 -0
  32. pycontrails/datalib/geo_utils.py +261 -0
  33. pycontrails/datalib/gfs/__init__.py +28 -0
  34. pycontrails/datalib/gfs/gfs.py +656 -0
  35. pycontrails/datalib/gfs/variables.py +104 -0
  36. pycontrails/datalib/goes.py +764 -0
  37. pycontrails/datalib/gruan.py +343 -0
  38. pycontrails/datalib/himawari/__init__.py +27 -0
  39. pycontrails/datalib/himawari/header_struct.py +266 -0
  40. pycontrails/datalib/himawari/himawari.py +671 -0
  41. pycontrails/datalib/landsat.py +589 -0
  42. pycontrails/datalib/leo_utils/__init__.py +5 -0
  43. pycontrails/datalib/leo_utils/correction.py +266 -0
  44. pycontrails/datalib/leo_utils/landsat_metadata.py +300 -0
  45. pycontrails/datalib/leo_utils/search.py +250 -0
  46. pycontrails/datalib/leo_utils/sentinel_metadata.py +748 -0
  47. pycontrails/datalib/leo_utils/static/bq_roi_query.sql +6 -0
  48. pycontrails/datalib/leo_utils/vis.py +59 -0
  49. pycontrails/datalib/sentinel.py +650 -0
  50. pycontrails/datalib/spire/__init__.py +5 -0
  51. pycontrails/datalib/spire/exceptions.py +62 -0
  52. pycontrails/datalib/spire/spire.py +604 -0
  53. pycontrails/ext/bada.py +42 -0
  54. pycontrails/ext/cirium.py +14 -0
  55. pycontrails/ext/empirical_grid.py +140 -0
  56. pycontrails/ext/synthetic_flight.py +431 -0
  57. pycontrails/models/__init__.py +1 -0
  58. pycontrails/models/accf.py +425 -0
  59. pycontrails/models/apcemm/__init__.py +8 -0
  60. pycontrails/models/apcemm/apcemm.py +983 -0
  61. pycontrails/models/apcemm/inputs.py +226 -0
  62. pycontrails/models/apcemm/static/apcemm_yaml_template.yaml +183 -0
  63. pycontrails/models/apcemm/utils.py +437 -0
  64. pycontrails/models/cocip/__init__.py +29 -0
  65. pycontrails/models/cocip/cocip.py +2742 -0
  66. pycontrails/models/cocip/cocip_params.py +305 -0
  67. pycontrails/models/cocip/cocip_uncertainty.py +291 -0
  68. pycontrails/models/cocip/contrail_properties.py +1530 -0
  69. pycontrails/models/cocip/output_formats.py +2270 -0
  70. pycontrails/models/cocip/radiative_forcing.py +1260 -0
  71. pycontrails/models/cocip/radiative_heating.py +520 -0
  72. pycontrails/models/cocip/unterstrasser_wake_vortex.py +508 -0
  73. pycontrails/models/cocip/wake_vortex.py +396 -0
  74. pycontrails/models/cocip/wind_shear.py +120 -0
  75. pycontrails/models/cocipgrid/__init__.py +9 -0
  76. pycontrails/models/cocipgrid/cocip_grid.py +2552 -0
  77. pycontrails/models/cocipgrid/cocip_grid_params.py +138 -0
  78. pycontrails/models/dry_advection.py +602 -0
  79. pycontrails/models/emissions/__init__.py +21 -0
  80. pycontrails/models/emissions/black_carbon.py +599 -0
  81. pycontrails/models/emissions/emissions.py +1353 -0
  82. pycontrails/models/emissions/ffm2.py +336 -0
  83. pycontrails/models/emissions/static/default-engine-uids.csv +239 -0
  84. pycontrails/models/emissions/static/edb-gaseous-v29b-engines.csv +596 -0
  85. pycontrails/models/emissions/static/edb-nvpm-v29b-engines.csv +215 -0
  86. pycontrails/models/extended_k15.py +1327 -0
  87. pycontrails/models/humidity_scaling/__init__.py +37 -0
  88. pycontrails/models/humidity_scaling/humidity_scaling.py +1075 -0
  89. pycontrails/models/humidity_scaling/quantiles/era5-model-level-quantiles.pq +0 -0
  90. pycontrails/models/humidity_scaling/quantiles/era5-pressure-level-quantiles.pq +0 -0
  91. pycontrails/models/issr.py +210 -0
  92. pycontrails/models/pcc.py +326 -0
  93. pycontrails/models/pcr.py +154 -0
  94. pycontrails/models/ps_model/__init__.py +18 -0
  95. pycontrails/models/ps_model/ps_aircraft_params.py +381 -0
  96. pycontrails/models/ps_model/ps_grid.py +701 -0
  97. pycontrails/models/ps_model/ps_model.py +1000 -0
  98. pycontrails/models/ps_model/ps_operational_limits.py +525 -0
  99. pycontrails/models/ps_model/static/ps-aircraft-params-20250328.csv +69 -0
  100. pycontrails/models/ps_model/static/ps-synonym-list-20250328.csv +104 -0
  101. pycontrails/models/sac.py +442 -0
  102. pycontrails/models/tau_cirrus.py +183 -0
  103. pycontrails/physics/__init__.py +1 -0
  104. pycontrails/physics/constants.py +117 -0
  105. pycontrails/physics/geo.py +1138 -0
  106. pycontrails/physics/jet.py +968 -0
  107. pycontrails/physics/static/iata-cargo-load-factors-20250221.csv +74 -0
  108. pycontrails/physics/static/iata-passenger-load-factors-20250221.csv +74 -0
  109. pycontrails/physics/thermo.py +551 -0
  110. pycontrails/physics/units.py +472 -0
  111. pycontrails/py.typed +0 -0
  112. pycontrails/utils/__init__.py +1 -0
  113. pycontrails/utils/dependencies.py +66 -0
  114. pycontrails/utils/iteration.py +13 -0
  115. pycontrails/utils/json.py +187 -0
  116. pycontrails/utils/temp.py +50 -0
  117. pycontrails/utils/types.py +163 -0
  118. pycontrails-0.59.0.dist-info/METADATA +179 -0
  119. pycontrails-0.59.0.dist-info/RECORD +123 -0
  120. pycontrails-0.59.0.dist-info/WHEEL +6 -0
  121. pycontrails-0.59.0.dist-info/licenses/LICENSE +178 -0
  122. pycontrails-0.59.0.dist-info/licenses/NOTICE +43 -0
  123. pycontrails-0.59.0.dist-info/top_level.txt +3 -0
@@ -0,0 +1,2936 @@
1
+ """Meteorology data models."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import hashlib
6
+ import json
7
+ import logging
8
+ import pathlib
9
+ import sys
10
+ import typing
11
+ import warnings
12
+ from abc import ABC, abstractmethod
13
+ from collections.abc import (
14
+ Generator,
15
+ Hashable,
16
+ Iterable,
17
+ Iterator,
18
+ Mapping,
19
+ MutableMapping,
20
+ Sequence,
21
+ )
22
+ from contextlib import ExitStack
23
+ from datetime import datetime
24
+ from typing import (
25
+ TYPE_CHECKING,
26
+ Any,
27
+ Generic,
28
+ Literal,
29
+ Self,
30
+ TypeVar,
31
+ overload,
32
+ )
33
+
34
+ if sys.version_info >= (3, 12):
35
+ from typing import override
36
+ else:
37
+ from typing_extensions import override
38
+
39
+ import numpy as np
40
+ import numpy.typing as npt
41
+ import pandas as pd
42
+ import xarray as xr
43
+
44
+ from pycontrails.core import interpolation
45
+ from pycontrails.core import vector as vector_module
46
+ from pycontrails.core.cache import CacheStore, DiskCacheStore
47
+ from pycontrails.core.met_var import AirPressure, Altitude, MetVariable
48
+ from pycontrails.physics import units
49
+ from pycontrails.utils import dependencies
50
+ from pycontrails.utils import temp as temp_module
51
+
52
+ logger = logging.getLogger(__name__)
53
+
54
+ # optional imports
55
+ if TYPE_CHECKING:
56
+ import open3d as o3d
57
+
58
+ XArrayType = TypeVar("XArrayType", xr.Dataset, xr.DataArray)
59
+ MetDataType = TypeVar("MetDataType", "MetDataset", "MetDataArray")
60
+ DatasetType = TypeVar("DatasetType", xr.Dataset, "MetDataset")
61
+
62
+ COORD_DTYPE = np.float64
63
+
64
+
65
+ class MetBase(ABC, Generic[XArrayType]):
66
+ """Abstract class for building Meteorology Data handling classes.
67
+
68
+ All support here should be generic to work on xr.DataArray
69
+ and xr.Dataset.
70
+ """
71
+
72
+ __slots__ = ("cachestore", "data")
73
+
74
+ #: DataArray or Dataset
75
+ data: XArrayType
76
+
77
+ #: Cache datastore to use for :meth:`save` or :meth:`load`
78
+ cachestore: CacheStore | None
79
+
80
+ #: Default dimension order for DataArray or Dataset (x, y, z, t)
81
+ dim_order = (
82
+ "longitude",
83
+ "latitude",
84
+ "level",
85
+ "time",
86
+ )
87
+
88
+ @classmethod
89
+ def _from_fastpath(cls, data: XArrayType, cachestore: CacheStore | None = None) -> Self:
90
+ """Create new instance from consistent data.
91
+
92
+ This is a low-level method that bypasses the standard constructor in certain
93
+ special cases. It is intended for internal use only.
94
+
95
+ In essence, this method skips any validation from __init__ and directly sets
96
+ ``data`` and ``cachestore``. This is useful when creating a new instance from an existing
97
+ instance where the data has already been validated.
98
+ """
99
+ obj = cls.__new__(cls)
100
+ obj.data = data
101
+ obj.cachestore = cachestore
102
+ return obj
103
+
104
+ def __repr__(self) -> str:
105
+ data = getattr(self, "data", None)
106
+ return (
107
+ f"{self.__class__.__name__} with data:\n\n{data.__repr__() if data is not None else ''}"
108
+ )
109
+
110
+ def _repr_html_(self) -> str:
111
+ try:
112
+ return f"<b>{type(self).__name__}</b> with data:<br/ ><br/> {self.data._repr_html_()}"
113
+ except AttributeError:
114
+ return f"<b>{type(self).__name__}</b> without data"
115
+
116
+ def _validate_dim_contains_coords(self) -> None:
117
+ """Check that data contains four temporal-spatial coordinates.
118
+
119
+ Raises
120
+ ------
121
+ ValueError
122
+ If data does not contain all four coordinates (longitude, latitude, level, time).
123
+ """
124
+ missing = set(self.dim_order).difference(self.data.dims)
125
+ if not missing:
126
+ return
127
+
128
+ dim = sorted(missing)
129
+ msg = f"Meteorology data must contain dimension(s): {dim}."
130
+ if "level" in dim:
131
+ msg += (
132
+ " For single level data, set 'level' coordinate to constant -1 "
133
+ "using `ds = ds.expand_dims({'level': [-1]})`"
134
+ )
135
+ raise ValueError(msg)
136
+
137
+ def _validate_longitude(self) -> None:
138
+ """Check longitude bounds.
139
+
140
+ Assumes ``longitude`` dimension is already sorted.
141
+
142
+ Raises
143
+ ------
144
+ ValueError
145
+ If longitude values are not contained in the interval [-180, 180].
146
+ """
147
+ longitude = self.indexes["longitude"].to_numpy()
148
+ if longitude.dtype != COORD_DTYPE:
149
+ msg = f"Longitude values must have dtype {COORD_DTYPE}. Instantiate with 'copy=True'."
150
+ raise ValueError(msg)
151
+
152
+ if self.is_wrapped:
153
+ # Relax verification if the longitude has already been processed and wrapped
154
+ if longitude[-1] > 360.0:
155
+ raise ValueError(
156
+ "Longitude contains values > 360. Shift to WGS84 with "
157
+ "'data.assign_coords(longitude=(((data.longitude + 180) % 360) - 180))'"
158
+ )
159
+ if longitude[0] < -360.0:
160
+ raise ValueError(
161
+ "Longitude contains values < -360. Shift to WGS84 with "
162
+ "'data.assign_coords(longitude=(((data.longitude + 180) % 360) - 180))'"
163
+ )
164
+ return
165
+
166
+ # Strict!
167
+ if longitude[-1] > 180.0:
168
+ raise ValueError(
169
+ "Longitude contains values > 180. Shift to WGS84 with "
170
+ "'data.assign_coords(longitude=(((data.longitude + 180) % 360) - 180))'"
171
+ )
172
+ if longitude[0] < -180.0:
173
+ raise ValueError(
174
+ "Longitude contains values < -180. Shift to WGS84 with "
175
+ "'data.assign_coords(longitude=(((data.longitude + 180) % 360) - 180))'"
176
+ )
177
+
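The shift suggested in these error messages is the usual [0, 360) to [-180, 180) conversion. A minimal sketch on a hypothetical ``ds`` whose longitudes run from 0 to 359.75:

>>> ds = ds.assign_coords(longitude=(((ds.longitude + 180) % 360) - 180))  # shift to WGS84
>>> ds = ds.sortby("longitude")  # restore ascending order before constructing a MetDataset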
178
+ def _validate_latitude(self) -> None:
179
+ """Check latitude bounds.
180
+
181
+ Assumes ``latitude`` dimension is already sorted.
182
+
183
+ Raises
184
+ ------
185
+ ValueError
186
+ If latitude values are not contained in the interval [-90, 90].
187
+ """
188
+ latitude = self.indexes["latitude"].to_numpy()
189
+ if latitude.dtype != COORD_DTYPE:
190
+ msg = f"Latitude values must have dtype {COORD_DTYPE}. Instantiate with 'copy=True'."
191
+ raise ValueError(msg)
192
+
193
+ if latitude[0] < -90.0:
194
+ raise ValueError(
195
+ "Latitude contains values < -90 . "
196
+ "Latitude values must be contained in the interval [-90, 90]."
197
+ )
198
+ if latitude[-1] > 90.0:
199
+ raise ValueError(
200
+ "Latitude contains values > 90 . "
201
+ "Latitude values must be contained in the interval [-90, 90]."
202
+ )
203
+
204
+ def _validate_sorting(self) -> None:
205
+ """Check that all coordinates are sorted.
206
+
207
+ Raises
208
+ ------
209
+ ValueError
210
+ If one of the coordinates is not sorted or contains duplicate values.
211
+ """
212
+ indexes = self.indexes
213
+ for coord in self.dim_order:
214
+ arr = indexes[coord]
215
+ d = np.diff(arr)
216
+ zero = np.zeros((), dtype=d.dtype) # ensure same dtype
217
+
218
+ if np.any(d <= zero):
219
+ if np.any(d == zero):
220
+ msg = f"Coordinate '{coord}' contains duplicate values."
221
+ else:
222
+ msg = f"Coordinate '{coord}' not sorted."
223
+
224
+ msg += " Instantiate with 'copy=True'."
225
+ raise ValueError(msg)
226
+
227
+ def _validate_transpose(self) -> None:
228
+ """Check that data is transposed according to :attr:`dim_order`."""
229
+
230
+ def _check_da(da: xr.DataArray, key: Hashable | None = None) -> None:
231
+ if da.dims != self.dim_order:
232
+ if key is not None:
233
+ msg = (
234
+ f"Data dimension not transposed on variable '{key}'. "
235
+ "Instantiate with 'copy=True'."
236
+ )
237
+ else:
238
+ msg = "Data dimension not transposed. Instantiate with 'copy=True'."
239
+ raise ValueError(msg)
240
+
241
+ data = self.data
242
+ if isinstance(data, xr.DataArray):
243
+ _check_da(data)
244
+ return
245
+
246
+ for key, da in self.data.items():
247
+ _check_da(da, key)
248
+
249
+ def _validate_dims(self) -> None:
250
+ """Apply all validators."""
251
+ self._validate_dim_contains_coords()
252
+
253
+ # Apply this one first: validate_longitude and validate_latitude assume sorted
254
+ self._validate_sorting()
255
+ self._validate_longitude()
256
+ self._validate_latitude()
257
+ self._validate_transpose()
258
+ if self.data["level"].dtype != COORD_DTYPE:
259
+ msg = f"Level values must have dtype {COORD_DTYPE}. Instantiate with 'copy=True'."
260
+ raise ValueError(msg)
261
+
262
+ def _preprocess_dims(self, wrap_longitude: bool) -> None:
263
+ """Confirm DataArray or Dataset include required dimension in a consistent format.
264
+
265
+ Expects the DataArray or Dataset to contain dimensions ``latitude``, ``longitude``, ``time``,
266
+ and ``level`` (in hPa/mbar).
267
+ Adds ``air_pressure`` and ``altitude`` coordinate variables
268
+ mapped to the "level" dimension when "level" > 0.
269
+
270
+ Set ``level`` to -1 to signify single level.
271
+
272
+ .. versionchanged:: 0.40.0
273
+
274
+ All coordinate data (longitude, latitude, level) are promoted to ``float64``.
275
+ Auxiliary coordinates (altitude and air_pressure) are now cast to the same
276
+ dtype as the underlying grid data.
277
+
278
+ .. versionchanged:: 0.58.0
279
+
280
+ Duplicate dimension values are dropped, keeping the first occurrence.
281
+
282
+
283
+ Parameters
284
+ ----------
285
+ wrap_longitude : bool
286
+ If True, ensure longitude values cover the interval ``[-180, 180]``.
287
+
288
+ Raises
289
+ ------
290
+ ValueError
291
+ If required dimension names are not found.
292
+ """
293
+ self._validate_dim_contains_coords()
294
+
295
+ # Ensure spatial coordinates all have dtype COORD_DTYPE
296
+ indexes = self.indexes
297
+ for coord in ("longitude", "latitude", "level"):
298
+ arr = indexes[coord].to_numpy()
299
+ if arr.dtype != COORD_DTYPE:
300
+ self.data[coord] = arr.astype(COORD_DTYPE)
301
+
302
+ # Ensure time is np.datetime64[ns]
303
+ self.data["time"] = self.data["time"].astype("datetime64[ns]", copy=False)
304
+
305
+ # sortby to ensure each coordinate has ascending order
306
+ self.data = self.data.sortby(list(self.dim_order), ascending=True)
307
+
308
+ # Drop any duplicated dimension values
309
+ indexes = self.indexes
310
+ for coord in self.dim_order:
311
+ arr = indexes[coord]
312
+ d = np.diff(arr)
313
+ zero = np.zeros((), dtype=d.dtype) # ensure same dtype
314
+
315
+ if np.any(d == zero):
316
+ # Remove duplicates
317
+ filt = np.r_[True, d > zero] # prepend True keeps the first occurrence
318
+ self.data = self.data.isel({coord: filt})
319
+
320
+ if not self.is_wrapped:
321
+ # Ensure longitude is contained in interval [-180, 180)
322
+ # If longitude has value at 180, we might not want to shift it?
323
+ lon = self.indexes["longitude"].to_numpy()
324
+
325
+ # This longitude shifting can give rise to precision errors with float32
326
+ # Only shift if necessary
327
+ if np.any(lon >= 180.0) or np.any(lon < -180.0):
328
+ self.data = shift_longitude(self.data)
329
+ else:
330
+ self.data = self.data.sortby("longitude", ascending=True)
331
+
332
+ # wrap longitude, if requested
333
+ if wrap_longitude:
334
+ self.data = _wrap_longitude(self.data)
335
+
336
+ self._validate_longitude()
337
+ self._validate_latitude()
338
+
339
+ # transpose to have ordering (x, y, z, t, ...)
340
+ dim_order = [*self.dim_order, *(d for d in self.data.dims if d not in self.dim_order)]
341
+ self.data = self.data.transpose(*dim_order)
342
+
343
+ # single level data
344
+ if self.is_single_level:
345
+ # add level attributes to reflect surface level
346
+ level_attrs = self.data["level"].attrs
347
+ if not level_attrs:
348
+ level_attrs.update(units="", long_name="Single Level")
349
+ return
350
+
351
+ self.data = _add_vertical_coords(self.data)
352
+
353
+ @property
354
+ def hash(self) -> str:
355
+ """Generate a unique hash for this met instance.
356
+
357
+ Note this is not as robust as it could be since :func:`repr` truncates its output.
358
+
359
+ Returns
360
+ -------
361
+ str
362
+ Unique hash for met instance (sha1)
363
+ """
364
+ _hash = repr(self.data)
365
+ return hashlib.sha1(bytes(_hash, "utf-8")).hexdigest()
366
+
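For reference, the digest can be reproduced directly from the ``repr`` (a sketch; ``met`` is any MetDataset or MetDataArray):

>>> import hashlib
>>> hashlib.sha1(repr(met.data).encode("utf-8")).hexdigest() == met.hash  # same sha1 of the repr
True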
367
+ @property
368
+ @abstractmethod
369
+ def size(self) -> int:
370
+ """Return the size of (each) array in underlying :attr:`data`.
371
+
372
+ Returns
373
+ -------
374
+ int
375
+ Total number of grid points in underlying data
376
+ """
377
+
378
+ @property
379
+ @abstractmethod
380
+ def shape(self) -> tuple[int, int, int, int]:
381
+ """Return the shape of the dimensions.
382
+
383
+ Returns
384
+ -------
385
+ tuple[int, int, int, int]
386
+ Shape of underlying data
387
+ """
388
+
389
+ @property
390
+ def coords(self) -> dict[str, np.ndarray]:
391
+ """Get coordinates of underlying :attr:`data` coordinates.
392
+
393
+ Only return non-dimension coordinates.
394
+
395
+ See:
396
+ http://xarray.pydata.org/en/stable/user-guide/data-structures.html#coordinates
397
+
398
+ Returns
399
+ -------
400
+ dict[str, np.ndarray]
401
+ Dictionary of coordinates
402
+ """
403
+ variables = self.indexes
404
+ return {
405
+ "longitude": variables["longitude"].to_numpy(),
406
+ "latitude": variables["latitude"].to_numpy(),
407
+ "level": variables["level"].to_numpy(),
408
+ "time": variables["time"].to_numpy(),
409
+ }
410
+
411
+ @property
412
+ def indexes(self) -> dict[Hashable, pd.Index]:
413
+ """Low level access to underlying :attr:`data` indexes.
414
+
415
+ This method is typically faster for accessing coordinate indexes.
416
+
417
+ .. versionadded:: 0.25.2
418
+
419
+ Returns
420
+ -------
421
+ dict[Hashable, pd.Index]
422
+ Dictionary of indexes.
423
+
424
+ Examples
425
+ --------
426
+ >>> from pycontrails.datalib.ecmwf import ERA5
427
+ >>> times = (datetime(2022, 3, 1, 12), datetime(2022, 3, 1, 13))
428
+ >>> variables = "air_temperature", "specific_humidity"
429
+ >>> levels = [200, 300]
430
+ >>> era5 = ERA5(times, variables, levels)
431
+ >>> mds = era5.open_metdataset()
432
+ >>> mds.indexes["level"].to_numpy()
433
+ array([200., 300.])
434
+
435
+ >>> mda = mds["air_temperature"]
436
+ >>> mda.indexes["level"].to_numpy()
437
+ array([200., 300.])
438
+ """
439
+ return {k: v.index for k, v in self.data._indexes.items()} # type: ignore[attr-defined]
440
+
441
+ @property
442
+ def is_wrapped(self) -> bool:
443
+ """Check if the longitude dimension covers the closed interval ``[-180, 180]``.
444
+
445
+ Assumes the longitude dimension is sorted (this is established by the
446
+ :class:`MetDataset` or :class:`MetDataArray` constructor).
447
+
448
+ .. versionchanged:: 0.26.0
449
+
450
+ The previous implementation checked for the minimum and maximum longitude
451
+ dimension values to be duplicated. The current implementation only checks
452
+ that the interval ``[-180, 180]`` is covered by the longitude dimension. The
453
+ :func:`pycontrails.physics.geo.advect_longitude` function is designed for compatibility
454
+ with this convention.
455
+
456
+ Returns
457
+ -------
458
+ bool
459
+ True if longitude coordinates cover ``[-180, 180]``
460
+
461
+ See Also
462
+ --------
463
+ :func:`pycontrails.physics.geo.advect_longitude`
464
+ """
465
+ longitude = self.indexes["longitude"].to_numpy()
466
+ return _is_wrapped(longitude)
467
+
468
+ @property
469
+ def is_single_level(self) -> bool:
470
+ """Check if instance contains "single level" or "surface level" data.
471
+
472
+ This method checks if the ``level`` dimension contains a single value equal
474
+ to -1, the pycontrails convention for surface-only data.
474
+
475
+ Returns
476
+ -------
477
+ bool
478
+ If instance contains single level data.
479
+ """
480
+ level = self.indexes["level"].to_numpy()
481
+ return len(level) == 1 and level[0] == -1
482
+
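A sketch of the single-level convention (``ds`` is a hypothetical xarray Dataset that already has longitude, latitude, and time dimensions but no vertical dimension):

>>> ds = ds.expand_dims({"level": [-1]})  # -1 marks surface / single-level data
>>> MetDataset(ds).is_single_level
True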
483
+ @abstractmethod
484
+ def broadcast_coords(self, name: str) -> xr.DataArray:
485
+ """Broadcast coordinates along other dimensions.
486
+
487
+ Parameters
488
+ ----------
489
+ name : str
490
+ Coordinate/dimension name to broadcast.
491
+ Can be a dimension or non-dimension coordinates.
492
+
493
+ Returns
494
+ -------
495
+ xr.DataArray
496
+ DataArray of the coordinate broadcasted along all other dimensions.
497
+ The DataArray will have the same shape as the gridded data.
498
+ """
499
+
500
+ def _save(self, dataset: xr.Dataset, **kwargs: Any) -> list[str]:
501
+ """Save dataset to netcdf files named for the met hash and hour.
502
+
503
+ Does not yet save in parallel.
504
+
505
+ .. versionchanged:: 0.34.1
506
+
507
+ If :attr:`cachestore` is None, this method assigns it
508
+ to a new :class:`DiskCacheStore`.
509
+
510
+ Parameters
511
+ ----------
512
+ dataset : xr.Dataset
513
+ Dataset to save
514
+ **kwargs
515
+ Keyword args passed directly on to :func:`xarray.save_mfdataset`
516
+
517
+ Returns
518
+ -------
519
+ list[str]
520
+ List of filenames saved
521
+ """
522
+ self.cachestore = self.cachestore or DiskCacheStore()
523
+
524
+ # group by hour and save one dataset for each hour to temp file
525
+ times, datasets = zip(*dataset.groupby("time", squeeze=False), strict=True)
526
+
527
+ # Open ExitStack to control temp_file context manager
528
+ with ExitStack() as stack:
529
+ xarray_temp_filenames = [stack.enter_context(temp_module.temp_file()) for _ in times]
530
+ xr.save_mfdataset(datasets, xarray_temp_filenames, **kwargs)
531
+
532
+ # set filenames by hash
533
+ filenames = [f"{self.hash}-{t_idx}.nc" for t_idx, _ in enumerate(times)]
534
+
535
+ # put each hourly file into cache
536
+ self.cachestore.put_multiple(xarray_temp_filenames, filenames)
537
+
538
+ return filenames
539
+
540
+ def __len__(self) -> int:
541
+ return self.data.__len__()
542
+
543
+ @property
544
+ def attrs(self) -> dict[str, Any]:
545
+ """Pass through to :attr:`self.data.attrs`."""
546
+ return self.data.attrs
547
+
548
+ def downselect(self, bbox: tuple[float, ...]) -> Self:
549
+ """Downselect met data within spatial bounding box.
550
+
551
+ Parameters
552
+ ----------
553
+ bbox : tuple[float, ...]
554
+ Sequence of coordinates defining a spatial bounding box in WGS84 coordinates.
555
+ For 2D queries, list is [west, south, east, north].
556
+ For 3D queries, list is [west, south, min-level, east, north, max-level]
557
+ with level defined in [:math:`hPa`].
558
+
559
+ Returns
560
+ -------
561
+ Self
562
+ Return downselected data
563
+ """
564
+ data = downselect(self.data, bbox)
565
+ return type(self)._from_fastpath(data, cachestore=self.cachestore)
566
+
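A usage sketch (``met`` is a MetDataset or MetDataArray; coordinates in degrees, levels in hPa):

>>> met_2d = met.downselect((-20.0, 30.0, 20.0, 60.0))  # west, south, east, north
>>> met_3d = met.downselect((-20.0, 30.0, 200.0, 20.0, 60.0, 300.0))  # adds min/max level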
567
+ @property
568
+ def is_zarr(self) -> bool:
569
+ """Check if underlying :attr:`data` is sourced from a Zarr group.
570
+
571
+ Implementation is very brittle, and may break as external libraries change.
572
+
573
+ Some ``dask`` intermediate artifact is cached when this is called. Typically,
574
+ subsequent calls to this method are much faster than the initial call.
575
+
576
+ .. versionadded:: 0.26.0
577
+
578
+ Returns
579
+ -------
580
+ bool
581
+ If ``data`` is based on a Zarr group.
582
+ """
583
+ return _is_zarr(self.data)
584
+
585
+ def downselect_met(
586
+ self,
587
+ met: MetDataType,
588
+ *,
589
+ longitude_buffer: tuple[float, float] = (0.0, 0.0),
590
+ latitude_buffer: tuple[float, float] = (0.0, 0.0),
591
+ level_buffer: tuple[float, float] = (0.0, 0.0),
592
+ time_buffer: tuple[np.timedelta64, np.timedelta64] = (
593
+ np.timedelta64(0, "h"),
594
+ np.timedelta64(0, "h"),
595
+ ),
596
+ ) -> MetDataType:
597
+ """Downselect ``met`` to encompass a spatiotemporal region of the data.
598
+
599
+ .. warning::
600
+
601
+ This method is analogous to :meth:`GeoVectorDataset.downselect_met`.
602
+ It does not change the instance data, but instead operates on the
603
+ ``met`` input. This method is different from :meth:`downselect` which
604
+ operates on the instance data.
605
+
606
+ .. versionchanged:: 0.54.5
607
+
608
+ Data is no longer copied when downselecting.
609
+
610
+ Parameters
611
+ ----------
612
+ met : MetDataset | MetDataArray
613
+ MetDataset or MetDataArray to downselect.
614
+ longitude_buffer : tuple[float, float], optional
615
+ Extend the longitude domain by ``longitude_buffer[0]`` on the low side
616
+ and ``longitude_buffer[1]`` on the high side.
617
+ Units must be the same as class coordinates.
618
+ Defaults to ``(0, 0)`` degrees.
619
+ latitude_buffer : tuple[float, float], optional
620
+ Extend the latitude domain by ``latitude_buffer[0]`` on the low side
621
+ and ``latitude_buffer[1]`` on the high side.
622
+ Units must be the same as class coordinates.
623
+ Defaults to ``(0, 0)`` degrees.
624
+ level_buffer : tuple[float, float], optional
625
+ Extend the level domain by ``level_buffer[0]`` on the low side
626
+ and ``level_buffer[1]`` on the high side.
627
+ Units must be the same as class coordinates.
628
+ Defaults to ``(0, 0)`` [:math:`hPa`].
629
+ time_buffer : tuple[np.timedelta64, np.timedelta64], optional
630
+ Extend the time domain by ``time_buffer[0]`` on the low side
631
+ and ``time_buffer[1]`` on the high side.
632
+ Units must be the same as class coordinates.
633
+ Defaults to ``(np.timedelta64(0, "h"), np.timedelta64(0, "h"))``.
634
+
635
+ Returns
636
+ -------
637
+ MetDataset | MetDataArray
638
+ Downselected MetDataset or MetDataArray.
639
+ """
640
+ indexes = self.indexes
641
+ lon = indexes["longitude"].to_numpy()
642
+ lat = indexes["latitude"].to_numpy()
643
+ level = indexes["level"].to_numpy()
644
+ time = indexes["time"].to_numpy()
645
+
646
+ vector = vector_module.GeoVectorDataset(
647
+ longitude=[lon.min(), lon.max()],
648
+ latitude=[lat.min(), lat.max()],
649
+ level=[level.min(), level.max()],
650
+ time=[time.min(), time.max()],
651
+ )
652
+
653
+ return vector.downselect_met(
654
+ met,
655
+ longitude_buffer=longitude_buffer,
656
+ latitude_buffer=latitude_buffer,
657
+ level_buffer=level_buffer,
658
+ time_buffer=time_buffer,
659
+ )
660
+
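A sketch contrasting this with :meth:`downselect` (``mds`` and ``other_met`` are hypothetical MetDatasets; the return value is the trimmed ``other_met``, not ``mds``):

>>> trimmed = mds.downselect_met(other_met, level_buffer=(10.0, 10.0))  # pad level range by 10 hPa per side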
661
+ def wrap_longitude(self) -> Self:
662
+ """Wrap longitude coordinates.
663
+
664
+ Returns
665
+ -------
666
+ Self
667
+ Copy of instance with wrapped longitude values.
668
+ Returns a copy of the data when longitude values are already wrapped.
669
+ """
670
+ return type(self)._from_fastpath(_wrap_longitude(self.data), cachestore=self.cachestore)
671
+
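A usage sketch (``met`` is a MetDataset or MetDataArray whose source grid does not already cover the full interval):

>>> met.is_wrapped
False
>>> met.wrap_longitude().is_wrapped  # duplicated/shifted longitude values now span [-180, 180]
True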
672
+ def copy(self) -> Self:
673
+ """Create a shallow copy of the current class.
674
+
675
+ See :meth:`xarray.Dataset.copy` for reference.
676
+
677
+ Returns
678
+ -------
679
+ Self
680
+ Copy of the current class
681
+ """
682
+ return type(self)._from_fastpath(self.data.copy(), cachestore=self.cachestore)
683
+
684
+
685
+ class MetDataset(MetBase):
686
+ """Meteorological dataset with multiple variables.
687
+
688
+ Composition around :class:`xarray.Dataset` to enforce certain
689
+ variables and dimensions for internal usage.
690
+
691
+ Parameters
692
+ ----------
693
+ data : xr.Dataset
694
+ :class:`xarray.Dataset` containing meteorological variables and coordinates
695
+ cachestore : :class:`CacheStore`, optional
696
+ Cache datastore for staging intermediates with :meth:`save`.
697
+ Defaults to None.
698
+ wrap_longitude : bool, optional
699
+ Wrap data along the longitude dimension. If True, duplicate and shift longitude
700
+ values (i.e., ``-180 -> 180``) to ensure that the longitude dimension covers the entire
701
+ interval ``[-180, 180]``. Defaults to False.
702
+ copy : bool, optional
703
+ Copy data on construction. Defaults to True.
704
+ attrs : dict[str, Any], optional
705
+ Attributes to add to :attr:`data.attrs`. Defaults to None. Generally, pycontrails
706
+ :class:`pycontrails.core.models.Models` may use the following attributes:
707
+
708
+ - ``provider``: Name of the data provider (e.g. ``"ECMWF"``).
709
+ - ``dataset``: Name of the dataset (e.g. ``"ERA5"``).
710
+ - ``product``: Name of the product type (e.g. ``"reanalysis"``).
711
+
712
+ **attrs_kwargs : Any
713
+ Keyword arguments to add to :attr:`data.attrs`. Defaults to None.
714
+
715
+ Examples
716
+ --------
717
+ >>> import numpy as np
718
+ >>> import pandas as pd
719
+ >>> import xarray as xr
720
+ >>> from pycontrails.datalib.ecmwf import ERA5
721
+
722
+ >>> time = ("2022-03-01T00", "2022-03-01T02")
723
+ >>> variables = ["air_temperature", "specific_humidity"]
724
+ >>> pressure_levels = [200, 250, 300]
725
+ >>> era5 = ERA5(time, variables, pressure_levels)
726
+
727
+ >>> # Open directly as `MetDataset`
728
+ >>> met = era5.open_metdataset()
729
+ >>> # Use `data` attribute to access `xarray` object
730
+ >>> assert isinstance(met.data, xr.Dataset)
731
+
732
+ >>> # Alternatively, open with `xarray` and cast to `MetDataset`
733
+ >>> ds = xr.open_mfdataset(era5._cachepaths)
734
+ >>> met = MetDataset(ds)
735
+
736
+ >>> # Access sub-`DataArrays`
737
+ >>> mda = met["t"] # `MetDataArray` instance, needed for interpolation operations
738
+ >>> da = mda.data # Underlying `xarray` object
739
+
740
+ >>> # Check out a few values
741
+ >>> da[5:8, 5:8, 1, 1].values
742
+ array([[224.08959005, 224.41374427, 224.75945349],
743
+ [224.09456429, 224.42037658, 224.76525676],
744
+ [224.10036756, 224.42617985, 224.77106004]])
745
+
746
+ >>> # Mean temperature over entire array
747
+ >>> da.mean().load().item()
748
+ 223.5083
749
+ """
750
+
751
+ __slots__ = ()
752
+
753
+ data: xr.Dataset
754
+
755
+ def __init__(
756
+ self,
757
+ data: xr.Dataset,
758
+ cachestore: CacheStore | None = None,
759
+ wrap_longitude: bool = False,
760
+ copy: bool = True,
761
+ attrs: dict[str, Any] | None = None,
762
+ **attrs_kwargs: Any,
763
+ ) -> None:
764
+ self.cachestore = cachestore
765
+
766
+ data.attrs.update(attrs or {}, **attrs_kwargs)
767
+
768
+ # ensure input is an xarray Dataset
769
+ if not isinstance(data, xr.Dataset):
770
+ raise TypeError("Input 'data' must be an xarray Dataset")
771
+
772
+ # copy Dataset into data
773
+ if copy:
774
+ self.data = data.copy()
775
+ self._preprocess_dims(wrap_longitude)
776
+
777
+ else:
778
+ if wrap_longitude:
779
+ raise ValueError("Set 'copy=True' when using 'wrap_longitude=True'.")
780
+ self.data = data
781
+ self._validate_dims()
782
+ if not self.is_single_level:
783
+ self.data = _add_vertical_coords(self.data)
784
+
785
+ def __getitem__(self, key: Hashable) -> MetDataArray:
786
+ """Return DataArray of variable ``key`` cast to a :class:`MetDataArray` object.
787
+
788
+ Parameters
789
+ ----------
790
+ key : Hashable
791
+ Variable name
792
+
793
+ Returns
794
+ -------
795
+ MetDataArray
796
+ MetDataArray instance associated to :attr:`data` variable `key`
797
+
798
+ Raises
799
+ ------
800
+ KeyError
801
+ If ``key`` not found in :attr:`data`
802
+ """
803
+ try:
804
+ da = self.data[key]
805
+ except KeyError as e:
806
+ raise KeyError(
807
+ f"Variable {key} not found. Available variables: {', '.join(self.data.data_vars)}. "
808
+ "To get items (e.g. 'time' or 'level') from underlying xr.Dataset object, "
809
+ "use the 'data' attribute."
810
+ ) from e
811
+ return MetDataArray._from_fastpath(da)
812
+
813
+ def get(self, key: str, default_value: Any = None) -> Any:
814
+ """Shortcut to :meth:`xarray.Dataset.get` method.
815
+
816
+ Parameters
817
+ ----------
818
+ key : str
819
+ Key to get from :attr:`data`
820
+ default_value : Any, optional
821
+ Return `default_value` if `key` not in :attr:`data`, by default `None`
822
+
823
+ Returns
824
+ -------
825
+ Any
826
+ Values returned from :attr:`data.get(key, default_value)`
827
+ """
828
+ return self.data.get(key, default_value)
829
+
830
+ def __setitem__(
831
+ self,
832
+ key: Hashable | list[Hashable] | Mapping,
833
+ value: Any,
834
+ ) -> None:
835
+ """Shortcut to set data variable on :attr:`data`.
836
+
837
+ Warns if ``key`` is already present in dataset.
838
+
839
+ Parameters
840
+ ----------
841
+ key : Hashable | list[Hashable] | Mapping
842
+ Variable name
843
+ value : Any
844
+ Value to set to variable names
845
+
846
+ See Also
847
+ --------
848
+ - :class:`xarray.Dataset.__setitem__`
849
+ """
850
+
851
+ # pull data of MetDataArray value
852
+ if isinstance(value, MetDataArray):
853
+ value = value.data
854
+
855
+ # warn if key is already in Dataset
856
+ override_keys: list[Hashable] = []
857
+
858
+ if isinstance(key, Hashable):
859
+ if key in self:
860
+ override_keys = [key]
861
+
862
+ # xarray.core.utils.is_dict_like
863
+ # https://github.com/pydata/xarray/blob/4cae8d0ec04195291b2315b1f21d846c2bad61ff/xarray/core/utils.py#L244
864
+ elif xr.core.utils.is_dict_like(key) or isinstance(key, list):
865
+ override_keys = [k for k in key if k in self.data]
866
+
867
+ if override_keys:
868
+ warnings.warn(
869
+ f"Overwriting data in keys `{override_keys}`. "
870
+ "Use `.update(...)` to suppress warning."
871
+ )
872
+
873
+ self.data.__setitem__(key, value)
874
+
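A sketch of the overwrite warning and how to avoid it (``met`` is a MetDataset containing ``air_temperature``):

>>> da = met["air_temperature"].data  # underlying xr.DataArray
>>> met["air_temperature"] = da + 1.0  # emits a warning: key already present
>>> met.update(air_temperature=da + 1.0)  # same assignment, no warning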
875
+ def update(self, other: MutableMapping | None = None, **kwargs: Any) -> None:
876
+ """Shortcut to :meth:`data.update`.
877
+
878
+ See :meth:`xarray.Dataset.update` for reference.
879
+
880
+ Parameters
881
+ ----------
882
+ other : MutableMapping
883
+ Variables with which to update this dataset
884
+ **kwargs : Any
885
+ Variables defined by keyword arguments. If a variable exists both in
886
+ ``other`` and as a keyword argument, the keyword argument takes
887
+ precedence.
888
+
889
+ See Also
890
+ --------
891
+ xarray.Dataset.update
892
+ """
893
+ other = other or {}
894
+ other.update(kwargs)
895
+
896
+ # pull data of MetDataArray value
897
+ for k, v in other.items():
898
+ if isinstance(v, MetDataArray):
899
+ other[k] = v.data
900
+
901
+ self.data.update(other)
902
+
903
+ def __iter__(self) -> Iterator[str]:
904
+ """Allow for the use as "key" in self.met, where "key" is a data variable."""
905
+ # From the typing perspective, `iter(self.data)`` returns Hashables (not
906
+ # necessarily strs). In everything we do, we use str variables.
907
+ # If we decide to extend this to support Hashable, we'll also want to
908
+ # change VectorDataset -- the underlying :attr:`data` should then be
909
+ # changed to `dict[Hashable, np.ndarray]`.
910
+ for key in self.data:
911
+ yield str(key)
912
+
913
+ def __contains__(self, key: Hashable) -> bool:
914
+ """Check if key ``key`` is in :attr:`data`.
915
+
916
+ Parameters
917
+ ----------
918
+ key : Hashable
919
+ Key to check
920
+
921
+ Returns
922
+ -------
923
+ bool
924
+ True if ``key`` is in :attr:`data`, False otherwise
925
+ """
926
+ return key in self.data
927
+
928
+ @property
929
+ @override
930
+ def shape(self) -> tuple[int, int, int, int]:
931
+ sizes = self.data.sizes
932
+ return sizes["longitude"], sizes["latitude"], sizes["level"], sizes["time"]
933
+
934
+ @property
935
+ @override
936
+ def size(self) -> int:
937
+ return np.prod(self.shape).item()
938
+
939
+ def ensure_vars(
940
+ self,
941
+ vars: MetVariable | str | Sequence[MetVariable | str | Sequence[MetVariable]],
942
+ raise_error: bool = True,
943
+ ) -> list[str]:
944
+ """Ensure variables exist in xr.Dataset.
945
+
946
+ Parameters
947
+ ----------
948
+ vars : MetVariable | str | Sequence[MetVariable | str | list[MetVariable]]
949
+ List of MetVariable (or string key), or individual MetVariable (or string key).
950
+ If ``vars`` contains an element with a list[MetVariable], then
951
+ at least one variable in the list must be present in the dataset.
952
+ raise_error : bool, optional
953
+ Raise KeyError if data does not contain variables.
954
+ Defaults to True.
955
+
956
+ Returns
957
+ -------
958
+ list[str]
959
+ List of met keys verified in :class:`MetDataset`.
960
+ Returns an empty list if any variable is missing and ``raise_error`` is False.
961
+
962
+ Raises
963
+ ------
964
+ KeyError
965
+ Raises when dataset does not contain variable in ``vars``
966
+ """
967
+ if isinstance(vars, MetVariable | str):
968
+ vars = (vars,)
969
+
970
+ met_keys: list[str] = []
971
+ for variable in vars:
972
+ met_key: str | None = None
973
+
974
+ # input is a MetVariable or str
975
+ if isinstance(variable, MetVariable):
976
+ if (key := variable.standard_name) in self:
977
+ met_key = key
978
+ elif isinstance(variable, str):
979
+ if variable in self:
980
+ met_key = variable
981
+
982
+ # otherwise, assume input is a sequence
983
+ # Sequence[MetVariable] means that any variable in list will work
984
+ else:
985
+ for v in variable:
986
+ if (key := v.standard_name) in self:
987
+ met_key = key
988
+ break
989
+
990
+ if met_key is None:
991
+ if not raise_error:
992
+ return []
993
+
994
+ if isinstance(variable, MetVariable):
995
+ raise KeyError(f"Dataset does not contain variable `{variable.standard_name}`")
996
+ if isinstance(variable, list):
997
+ raise KeyError(
998
+ "Dataset does not contain one of variables "
999
+ f"`{[v.standard_name for v in variable]}`"
1000
+ )
1001
+ raise KeyError(f"Dataset does not contain variable `{variable}`")
1002
+
1003
+ met_keys.append(met_key)
1004
+
1005
+ return met_keys
1006
+
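A usage sketch assuming the standard variable definitions in ``pycontrails.core.met_var`` and a dataset holding both variables:

>>> from pycontrails.core.met_var import AirTemperature, SpecificHumidity
>>> met.ensure_vars([AirTemperature, SpecificHumidity])
['air_temperature', 'specific_humidity']
>>> met.ensure_vars("nonexistent_variable", raise_error=False)  # returns [] instead of raising
[]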
1007
+ def save(self, **kwargs: Any) -> list[str]:
1008
+ """Save intermediate to :attr:`cachestore` as netcdf.
1009
+
1010
+ Load and restore using :meth:`load`.
1011
+
1012
+ Parameters
1013
+ ----------
1014
+ **kwargs : Any
1015
+ Keyword arguments passed directly to :func:`xarray.save_mfdataset`
1016
+
1017
+ Returns
1018
+ -------
1019
+ list[str]
1020
+ Returns filenames saved
1021
+ """
1022
+ return self._save(self.data, **kwargs)
1023
+
1024
+ @classmethod
1025
+ def load(
1026
+ cls,
1027
+ hash: str,
1028
+ cachestore: CacheStore | None = None,
1029
+ chunks: dict[str, int] | None = None,
1030
+ ) -> Self:
1031
+ """Load saved intermediate from :attr:`cachestore`.
1032
+
1033
+ Parameters
1034
+ ----------
1035
+ hash : str
1036
+ Saved hash to load.
1037
+ cachestore : :class:`CacheStore`, optional
1038
+ Cache datastore to use for sourcing files.
1039
+ Defaults to DiskCacheStore.
1040
+ chunks : dict[str, int], optional
1041
+ Chunks kwarg passed to :func:`xarray.open_mfdataset()` when opening files.
1042
+
1043
+ Returns
1044
+ -------
1045
+ Self
1046
+ New MetDataset with loaded data.
1047
+ """
1048
+ cachestore = cachestore or DiskCacheStore()
1049
+ chunks = chunks or {}
1050
+ data = _load(hash, cachestore, chunks)
1051
+ return cls(data)
1052
+
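A save/load round-trip sketch (assuming the default disk cache is used for both calls and ``met`` is a MetDataset as in the class examples):

>>> filenames = met.save()  # one netCDF file per hour, named by met.hash
>>> met2 = MetDataset.load(met.hash)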
1053
+ @override
1054
+ def broadcast_coords(self, name: str) -> xr.DataArray:
1055
+ da = xr.ones_like(self.data[next(iter(self.data.keys()))]) * self.data[name]
1056
+ da.name = name
1057
+
1058
+ return da
1059
+
1060
+ def to_vector(self, transfer_attrs: bool = True) -> vector_module.GeoVectorDataset:
1061
+ """Convert a :class:`MetDataset` to a :class:`GeoVectorDataset` by raveling data.
1062
+
1063
+ If :attr:`data` is lazy, it will be loaded.
1064
+
1065
+ Parameters
1066
+ ----------
1067
+ transfer_attrs : bool, optional
1068
+ Transfer attributes from :attr:`data` to output :class:`GeoVectorDataset`.
1069
+ By default, True, meaning that attributes are transferred.
1070
+
1071
+ Returns
1072
+ -------
1073
+ GeoVectorDataset
1074
+ Converted :class:`GeoVectorDataset`. The variables on the returned instance
1075
+ include all of those on the input instance, plus the four core spatiotemporal
1076
+ variables.
1077
+
1078
+ Examples
1079
+ --------
1080
+ >>> from pycontrails.datalib.ecmwf import ERA5
1081
+ >>> times = "2022-03-01", "2022-03-01T01"
1082
+ >>> variables = ["air_temperature", "specific_humidity"]
1083
+ >>> levels = [250, 200]
1084
+ >>> era5 = ERA5(time=times, variables=variables, pressure_levels=levels)
1085
+ >>> met = era5.open_metdataset()
1086
+ >>> met.to_vector(transfer_attrs=False)
1087
+ GeoVectorDataset [6 keys x 4152960 length, 0 attributes]
1088
+ Keys: longitude, latitude, level, time, air_temperature, ..., specific_humidity
1089
+ Attributes:
1090
+ time [2022-03-01 00:00:00, 2022-03-01 01:00:00]
1091
+ longitude [-180.0, 179.75]
1092
+ latitude [-90.0, 90.0]
1093
+ altitude [10362.8, 11783.9]
1094
+
1095
+ """
1096
+ coords_keys = self.data.dims
1097
+ indexes = self.indexes
1098
+ coords_vals = [indexes[key].values for key in coords_keys]
1099
+ coords_meshes = np.meshgrid(*coords_vals, indexing="ij")
1100
+ raveled_coords = (mesh.ravel() for mesh in coords_meshes)
1101
+ data = dict(zip(coords_keys, raveled_coords, strict=True))
1102
+
1103
+ out = vector_module.GeoVectorDataset(data, copy=False)
1104
+ for key, da in self.data.items():
1105
+ # The call to .values here will load the data if it is lazy
1106
+ out[key] = da.values.ravel() # type: ignore[index]
1107
+
1108
+ if transfer_attrs:
1109
+ out.attrs.update(self.attrs)
1110
+
1111
+ return out
1112
+
1113
+ def _get_pycontrails_attr_template(
1114
+ self,
1115
+ name: str,
1116
+ supported: tuple[str, ...],
1117
+ examples: dict[str, str],
1118
+ ) -> str:
1119
+ """Look up an attribute with a custom error message."""
1120
+ try:
1121
+ out = self.attrs[name]
1122
+ except KeyError as e:
1123
+ msg = f"Specify '{name}' attribute on underlying dataset."
1124
+
1125
+ for i, (k, v) in enumerate(examples.items()):
1126
+ if i == 0:
1127
+ msg = f"{msg} For example, set attrs['{name}'] = '{k}' for {v}."
1128
+ else:
1129
+ msg = f"{msg} Set attrs['{name}'] = '{k}' for {v}."
1130
+ raise KeyError(msg) from e
1131
+
1132
+ if out not in supported:
1133
+ warnings.warn(
1134
+ f"Unknown {name} '{out}'. Data may not be processed correctly. "
1135
+ f"Known {name}s are {supported}. Contact the pycontrails "
1136
+ "developers if you believe this is an error."
1137
+ )
1138
+
1139
+ return out
1140
+
1141
+ @property
1142
+ def provider_attr(self) -> str:
1143
+ """Look up the ``"provider"`` attribute with a custom error message.
1144
+
1145
+ Returns
1146
+ -------
1147
+ str
1148
+ Provider of the data. If not one of ``"ECMWF"`` or ``"NCEP"``,
1149
+ a warning is issued.
1150
+ """
1151
+ supported = ("ECMWF", "NCEP")
1152
+ examples = {"ECMWF": "data provided by ECMWF", "NCEP": "GFS data"}
1153
+ return self._get_pycontrails_attr_template("provider", supported, examples)
1154
+
1155
+ @property
1156
+ def dataset_attr(self) -> str:
1157
+ """Look up the ``"dataset"`` attribute with a custom error message.
1158
+
1159
+ Returns
1160
+ -------
1161
+ str
1162
+ Dataset of the data. If not one of ``"ERA5"``, ``"HRES"``, ``"IFS"``,
1163
+ or ``"GFS"``, a warning is issued.
1164
+ """
1165
+ supported = ("ERA5", "HRES", "IFS", "GFS")
1166
+ examples = {
1167
+ "ERA5": "ECMWF ERA5 reanalysis data",
1168
+ "HRES": "ECMWF HRES forecast data",
1169
+ "GFS": "NCEP GFS forecast data",
1170
+ }
1171
+ return self._get_pycontrails_attr_template("dataset", supported, examples)
1172
+
1173
+ @property
1174
+ def product_attr(self) -> str:
1175
+ """Look up the ``"product"`` attribute with a custom error message.
1176
+
1177
+ Returns
1178
+ -------
1179
+ str
1180
+ Product of the data. If not one of ``"forecast"``, ``"ensemble"``,
1181
+ or ``"reanalysis"``, a warning is issued.
1182
+
1183
+ """
1184
+ supported = ("reanalysis", "forecast", "ensemble")
1185
+ examples = {
1186
+ "reanalysis": "ECMWF ERA5 reanalysis data",
1187
+ "ensemble": "ECMWF ERA5 ensemble member data",
1188
+ }
1189
+ return self._get_pycontrails_attr_template("product", supported, examples)
1190
+
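These attributes can be supplied when wrapping a generic xarray Dataset (a sketch; ``ds`` is hypothetical):

>>> met = MetDataset(ds, provider="ECMWF", dataset="ERA5", product="reanalysis")
>>> met.provider_attr, met.dataset_attr, met.product_attr
('ECMWF', 'ERA5', 'reanalysis')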
1191
+ @overload
1192
+ def standardize_variables(
1193
+ self, variables: Iterable[MetVariable], inplace: Literal[False] = ...
1194
+ ) -> Self: ...
1195
+
1196
+ @overload
1197
+ def standardize_variables(
1198
+ self, variables: Iterable[MetVariable], inplace: Literal[True]
1199
+ ) -> None: ...
1200
+
1201
+ def standardize_variables(
1202
+ self, variables: Iterable[MetVariable], inplace: bool = False
1203
+ ) -> Self | None:
1204
+ """Standardize variable names.
1205
+
1206
+ .. versionchanged:: 0.54.7
1207
+
1208
+ By default, this method returns a new :class:`MetDataset` instead
1209
+ of renaming in place. To retain the old behavior, set ``inplace=True``.
1210
+ The ``inplace`` behavior is deprecated and will be removed in a future release.
1211
+
1212
+ Parameters
1213
+ ----------
1214
+ variables : Iterable[MetVariable]
1215
+ Data source variables
1216
+ inplace : bool, optional
1217
+ If True, rename variables in place. Otherwise, return a new
1218
+ :class:`MetDataset` with renamed variables.
1219
+
1220
+ See Also
1221
+ --------
1222
+ :func:`standardize_variables`
1223
+ """
1224
+ data_renamed = standardize_variables(self.data, variables)
1225
+
1226
+ if inplace:
1227
+ warnings.warn(
1228
+ "The inplace behavior is deprecated and will be removed in a future release. ",
1229
+ DeprecationWarning,
1230
+ stacklevel=2,
1231
+ )
1232
+ self.data = data_renamed
1233
+ return None
1234
+
1235
+ return type(self)._from_fastpath(data_renamed, cachestore=self.cachestore)
1236
+
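A sketch of renaming ECMWF short names to pycontrails standard names (assuming ``met`` holds variables named ``t`` and ``q``):

>>> from pycontrails.core.met_var import AirTemperature, SpecificHumidity
>>> met = met.standardize_variables([AirTemperature, SpecificHumidity])
>>> sorted(met)
['air_temperature', 'specific_humidity']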
1237
+ @classmethod
1238
+ def from_coords(
1239
+ cls,
1240
+ longitude: npt.ArrayLike | float,
1241
+ latitude: npt.ArrayLike | float,
1242
+ level: npt.ArrayLike | float,
1243
+ time: npt.ArrayLike | np.datetime64,
1244
+ ) -> Self:
1245
+ r"""Create a :class:`MetDataset` containing a coordinate skeleton from coordinate arrays.
1246
+
1247
+ Parameters
1248
+ ----------
1249
+ longitude, latitude : npt.ArrayLike | float
1250
+ Horizontal coordinates, in [:math:`\deg`]
1251
+ level : npt.ArrayLike | float
1252
+ Vertical coordinate, in [:math:`hPa`]
1253
+ time : npt.ArrayLike | np.datetime64
1254
+ Temporal coordinates, in [:math:`UTC`]. Will be sorted.
1255
+
1256
+ Returns
1257
+ -------
1258
+ Self
1259
+ MetDataset with no variables.
1260
+
1261
+ Examples
1262
+ --------
1263
+ >>> # Create skeleton MetDataset
1264
+ >>> longitude = np.arange(0, 10, 0.5)
1265
+ >>> latitude = np.arange(0, 10, 0.5)
1266
+ >>> level = [250, 300]
1267
+ >>> time = np.datetime64("2019-01-01")
1268
+ >>> met = MetDataset.from_coords(longitude, latitude, level, time)
1269
+ >>> met
1270
+ MetDataset with data:
1271
+ <xarray.Dataset> Size: 360B
1272
+ Dimensions: (longitude: 20, latitude: 20, level: 2, time: 1)
1273
+ Coordinates:
1274
+ * longitude (longitude) float64 160B 0.0 0.5 1.0 1.5 ... 8.0 8.5 9.0 9.5
1275
+ * latitude (latitude) float64 160B 0.0 0.5 1.0 1.5 ... 8.0 8.5 9.0 9.5
1276
+ * level (level) float64 16B 250.0 300.0
1277
+ * time (time) datetime64[ns] 8B 2019-01-01
1278
+ air_pressure (level) float32 8B 2.5e+04 3e+04
1279
+ altitude (level) float32 8B 1.036e+04 9.164e+03
1280
+ Data variables:
1281
+ *empty*
1282
+
1283
+ >>> met.shape
1284
+ (20, 20, 2, 1)
1285
+
1286
+ >>> met.size
1287
+ 800
1288
+
1289
+ >>> # Fill it up with some constant data
1290
+ >>> met["temperature"] = xr.DataArray(np.full(met.shape, 234.5), coords=met.coords)
1291
+ >>> met["humidity"] = xr.DataArray(np.full(met.shape, 0.5), coords=met.coords)
1292
+ >>> met
1293
+ MetDataset with data:
1294
+ <xarray.Dataset> Size: 13kB
1295
+ Dimensions: (longitude: 20, latitude: 20, level: 2, time: 1)
1296
+ Coordinates:
1297
+ * longitude (longitude) float64 160B 0.0 0.5 1.0 1.5 ... 8.0 8.5 9.0 9.5
1298
+ * latitude (latitude) float64 160B 0.0 0.5 1.0 1.5 ... 8.0 8.5 9.0 9.5
1299
+ * level (level) float64 16B 250.0 300.0
1300
+ * time (time) datetime64[ns] 8B 2019-01-01
1301
+ air_pressure (level) float32 8B 2.5e+04 3e+04
1302
+ altitude (level) float32 8B 1.036e+04 9.164e+03
1303
+ Data variables:
1304
+ temperature (longitude, latitude, level, time) float64 6kB 234.5 ... 234.5
1305
+ humidity (longitude, latitude, level, time) float64 6kB 0.5 0.5 ... 0.5
1306
+
1307
+ >>> # Convert to a GeoVectorDataset
1308
+ >>> vector = met.to_vector()
1309
+ >>> vector.dataframe.head()
1310
+ longitude latitude level time temperature humidity
1311
+ 0 0.0 0.0 250.0 2019-01-01 234.5 0.5
1312
+ 1 0.0 0.0 300.0 2019-01-01 234.5 0.5
1313
+ 2 0.0 0.5 250.0 2019-01-01 234.5 0.5
1314
+ 3 0.0 0.5 300.0 2019-01-01 234.5 0.5
1315
+ 4 0.0 1.0 250.0 2019-01-01 234.5 0.5
1316
+ """
1317
+ input_data = {
1318
+ "longitude": longitude,
1319
+ "latitude": latitude,
1320
+ "level": level,
1321
+ "time": time,
1322
+ }
1323
+
1324
+ # clean up input into coords
1325
+ coords: dict[str, np.ndarray] = {}
1326
+ for key, val in input_data.items():
1327
+ dtype = "datetime64[ns]" if key == "time" else COORD_DTYPE
1328
+ arr: np.ndarray = np.asarray(val, dtype=dtype)
1329
+
1330
+ if arr.ndim == 0:
1331
+ arr = arr.reshape(1)
1332
+ elif arr.ndim > 1:
1333
+ raise ValueError(f"{key} has too many dimensions")
1334
+
1335
+ arr = np.sort(arr)
1336
+ if arr.size == 0:
1337
+ raise ValueError(f"Coordinate {key} must be nonempty.")
1338
+
1339
+ coords[key] = arr
1340
+
1341
+ return cls(xr.Dataset({}, coords=coords))
1342
+
1343
+ @classmethod
1344
+ def from_zarr(cls, store: Any, **kwargs: Any) -> Self:
1345
+ """Create a :class:`MetDataset` from a path to a Zarr store.
1346
+
1347
+ Parameters
1348
+ ----------
1349
+ store : Any
1350
+ Path to Zarr store. Passed into :func:`xarray.open_zarr`.
1351
+ **kwargs : Any
1352
+ Other keyword only arguments passed into :func:`xarray.open_zarr`.
1353
+
1354
+ Returns
1355
+ -------
1356
+ Self
1357
+ MetDataset with data from Zarr store.
1358
+ """
1359
+ kwargs.setdefault("storage_options", {"read_only": True})
1360
+ ds = xr.open_zarr(store, **kwargs)
1361
+ return cls(ds)
1362
+
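A usage sketch (the store path and chunking are hypothetical):

>>> met = MetDataset.from_zarr("gs://my-bucket/era5.zarr", chunks={"time": 1})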
1363
+
1364
+ class MetDataArray(MetBase):
1365
+ """Meteorological DataArray of single variable.
1366
+
1367
+ Wrapper around :class:`xarray.DataArray` to enforce certain
1368
+ variables and dimensions for internal usage.
1369
+
1370
+ .. versionchanged:: 0.54.4
1371
+
1372
+ Remove ``validate`` parameter. Validation is now always performed.
1373
+
1374
+ Parameters
1375
+ ----------
1376
+ data : xr.DataArray
1377
+ DataArray containing the met variable with dimensions
1378
+ ``longitude``, ``latitude``, ``level``, and ``time``.
1380
+ cachestore : :class:`CacheStore`, optional
1381
+ Cache datastore for staging intermediates with :meth:`save`.
1382
+ Defaults to None.
1383
+ wrap_longitude : bool, optional
1384
+ Wrap data along the longitude dimension. If True, duplicate and shift longitude
1385
+ values (i.e., -180 -> 180) to ensure that the longitude dimension covers the entire
1386
+ interval ``[-180, 180]``. Defaults to False.
1387
+ copy : bool, optional
1388
+ Copy `data` parameter on construction, by default `True`. If `data` is lazy-loaded
1389
+ via `dask`, this parameter has no effect. If `data` is already loaded into memory,
1390
+ a copy of the data (rather than a view) may be created if `True`.
1391
+ name : Hashable, optional
1392
+ Name of the data variable. If not specified and the DataArray has no name, the name is set to "met".
1393
+
1394
+ Examples
1395
+ --------
1396
+ >>> import numpy as np
1397
+ >>> import xarray as xr
1398
+ >>> rng = np.random.default_rng(seed=456)
1399
+
1400
+ >>> # Cook up random xarray object
1401
+ >>> coords = {
1402
+ ... "longitude": np.arange(-20, 20),
1403
+ ... "latitude": np.arange(-30, 30),
1404
+ ... "level": [220, 240, 260, 280],
1405
+ ... "time": [np.datetime64("2021-08-01T12", "ns"), np.datetime64("2021-08-01T16", "ns")]
1406
+ ... }
1407
+ >>> da = xr.DataArray(rng.random((40, 60, 4, 2)), dims=coords.keys(), coords=coords)
1408
+
1409
+ >>> # Cast to `MetDataArray` in order to interpolate
1410
+ >>> from pycontrails import MetDataArray
1411
+ >>> mda = MetDataArray(da)
1412
+ >>> mda.interpolate(-11.4, 5.7, 234, np.datetime64("2021-08-01T13"))
1413
+ array([0.52358215])
1414
+
1415
+ >>> mda.interpolate(-11.4, 5.7, 234, np.datetime64("2021-08-01T13"), method='nearest')
1416
+ array([0.4188465])
1417
+
1418
+ >>> da.sel(longitude=-11, latitude=6, level=240, time=np.datetime64("2021-08-01T12")).item()
1419
+ 0.41884649899766946
1420
+ """
1421
+
1422
+ __slots__ = ()
1423
+
1424
+ data: xr.DataArray
1425
+
1426
+ def __init__(
1427
+ self,
1428
+ data: xr.DataArray,
1429
+ cachestore: CacheStore | None = None,
1430
+ wrap_longitude: bool = False,
1431
+ copy: bool = True,
1432
+ name: Hashable | None = None,
1433
+ ) -> None:
1434
+ self.cachestore = cachestore
1435
+
1436
+ if copy:
1437
+ self.data = data.copy()
1438
+ self._preprocess_dims(wrap_longitude)
1439
+ elif wrap_longitude:
1440
+ raise ValueError("Set 'copy=True' when using 'wrap_longitude=True'.")
1441
+ else:
1442
+ self.data = data
1443
+ self._validate_dims()
1444
+
1445
+ # Priority: name > data.name > "met"
1446
+ self.data.name = name or self.data.name or "met"
1447
+
1448
+ @property
1449
+ def values(self) -> np.ndarray:
1450
+ """Return underlying numpy array.
1451
+
1452
+ This method loads :attr:`data` if it is not already in memory.
1453
+
1454
+ Returns
1455
+ -------
1456
+ np.ndarray
1457
+ Underlying numpy array
1458
+
1459
+ See Also
1460
+ --------
1461
+ :meth:`xarray.Dataset.load`
1462
+ :meth:`xarray.DataArray.load`
1463
+
1464
+ """
1465
+ if not self.in_memory:
1466
+ self._check_memory("Extracting numpy array from")
1467
+ self.data.load()
1468
+
1469
+ return self.data.values
1470
+
1471
+ @property
1472
+ def name(self) -> Hashable:
1473
+ """Return the DataArray name.
1474
+
1475
+ Returns
1476
+ -------
1477
+ Hashable
1478
+ DataArray name
1479
+ """
1480
+ return self.data.name
1481
+
1482
+ @property
1483
+ def binary(self) -> bool:
1484
+ """Determine if all data is a binary value (0, 1).
1485
+
1486
+ Returns
1487
+ -------
1488
+ bool
1489
+ True if all data values are binary (0 or 1)
1490
+ """
1491
+ return np.array_equal(self.data, self.data.astype(bool))
1492
+
1493
+ @property
1494
+ @override
1495
+ def size(self) -> int:
1496
+ return self.data.size
1497
+
1498
+ @property
1499
+ @override
1500
+ def shape(self) -> tuple[int, int, int, int]:
1501
+ # https://github.com/python/mypy/issues/1178
1502
+ return typing.cast(tuple[int, int, int, int], self.data.shape)
1503
+
1504
+ @property
1505
+ def in_memory(self) -> bool:
1506
+ """Check if underlying :attr:`data` is loaded into memory.
1507
+
1508
+ This method uses protected attributes of underlying `xarray` objects, and may be subject
1509
+ to deprecation.
1510
+
1511
+ .. versionchanged:: 0.26.0
1512
+
1513
+ Rename from ``is_loaded`` to ``in_memory``.
1514
+
1515
+ Returns
1516
+ -------
1517
+ bool
1518
+ If underlying data exists as an `np.ndarray` in memory.
1519
+ """
1520
+ return self.data._in_memory
1521
+
1522
+ @overload
1523
+ def interpolate(
1524
+ self,
1525
+ longitude: float | npt.NDArray[np.floating],
1526
+ latitude: float | npt.NDArray[np.floating],
1527
+ level: float | npt.NDArray[np.floating],
1528
+ time: np.datetime64 | npt.NDArray[np.datetime64],
1529
+ *,
1530
+ method: str = ...,
1531
+ bounds_error: bool = ...,
1532
+ fill_value: float | np.float64 | None = ...,
1533
+ localize: bool = ...,
1534
+ lowmem: bool = ...,
1535
+ indices: interpolation.RGIArtifacts | None = ...,
1536
+ return_indices: Literal[False] = ...,
1537
+ ) -> npt.NDArray[np.floating]: ...
1538
+
1539
+ @overload
1540
+ def interpolate(
1541
+ self,
1542
+ longitude: float | npt.NDArray[np.floating],
1543
+ latitude: float | npt.NDArray[np.floating],
1544
+ level: float | npt.NDArray[np.floating],
1545
+ time: np.datetime64 | npt.NDArray[np.datetime64],
1546
+ *,
1547
+ method: str = ...,
1548
+ bounds_error: bool = ...,
1549
+ fill_value: float | np.float64 | None = ...,
1550
+ localize: bool = ...,
1551
+ lowmem: bool = ...,
1552
+ indices: interpolation.RGIArtifacts | None = ...,
1553
+ return_indices: Literal[True],
1554
+ ) -> tuple[npt.NDArray[np.floating], interpolation.RGIArtifacts]: ...
1555
+
1556
+ def interpolate(
1557
+ self,
1558
+ longitude: float | npt.NDArray[np.floating],
1559
+ latitude: float | npt.NDArray[np.floating],
1560
+ level: float | npt.NDArray[np.floating],
1561
+ time: np.datetime64 | npt.NDArray[np.datetime64],
1562
+ *,
1563
+ method: str = "linear",
1564
+ bounds_error: bool = False,
1565
+ fill_value: float | np.float64 | None = np.nan,
1566
+ localize: bool = False,
1567
+ lowmem: bool = False,
1568
+ indices: interpolation.RGIArtifacts | None = None,
1569
+ return_indices: bool = False,
1570
+ ) -> npt.NDArray[np.floating] | tuple[npt.NDArray[np.floating], interpolation.RGIArtifacts]:
1571
+ """Interpolate values over underlying DataArray.
1572
+
1573
+ Zero dimensional coordinates are reshaped to 1D arrays.
1574
+
1575
+ If ``lowmem == False``, the method automatically loads the underlying :attr:`data` into
1576
+ memory. Otherwise, the method iterates through smaller subsets of :attr:`data` and releases
1577
+ subsets from memory once interpolation against each subset is finished.
1578
+
1579
+ If ``method == "nearest"``, the out array will have the same ``dtype`` as
1580
+ the underlying :attr:`data`.
1581
+
1582
+ If ``method == "linear"``, the out array will be promoted to the most
1583
+ precise ``dtype`` of:
1584
+
1585
+ - underlying :attr:`data`
1586
+ - :attr:`data.longitude`
1587
+ - :attr:`data.latitude`
1588
+ - :attr:`data.level`
1589
+ - ``longitude``
1590
+ - ``latitude``
1591
+
1592
+ .. versionadded:: 0.24
1593
+
1594
+ This method can now handle singleton dimensions with ``method == "linear"``.
1595
+ Previously these degenerate dimensions caused nan values to be returned.
1596
+
1597
+ Parameters
1598
+ ----------
1599
+ longitude : float | npt.NDArray[np.floating]
1600
+ Longitude values to interpolate. Assumed to be 0 or 1 dimensional.
1601
+ latitude : float | npt.NDArray[np.floating]
1602
+ Latitude values to interpolate. Assumed to be 0 or 1 dimensional.
1603
+ level : float | npt.NDArray[np.floating]
1604
+ Level values to interpolate. Assumed to be 0 or 1 dimensional.
1605
+ time : np.datetime64 | npt.NDArray[np.datetime64]
1606
+ Time values to interpolate. Assumed to be 0 or 1 dimensional.
1607
+ method: str, optional
1608
+ Additional keyword arguments to pass to
1609
+ :class:`scipy.interpolate.RegularGridInterpolator`.
1610
+ Defaults to "linear".
1611
+ bounds_error: bool, optional
1612
+ Additional keyword arguments to pass to
1613
+ :class:`scipy.interpolate.RegularGridInterpolator`.
1614
+ Defaults to ``False``.
1615
+ fill_value: float | np.float64, optional
1616
+ Additional keyword arguments to pass to
1617
+ :class:`scipy.interpolate.RegularGridInterpolator`.
1618
+ Set to None to extrapolate outside the boundary when ``method`` is ``nearest``.
1619
+ Defaults to ``np.nan``.
1620
+ localize: bool, optional
1621
+ Experimental. If True, downselect gridded data to smallest bounding box containing
1622
+ all points. By default False.
1623
+ lowmem: bool, optional
1624
+ Experimental. If True, iterate through points binned by the time coordinate of the
1625
+ grided data, and downselect gridded data to the smallest bounding box containing
1626
+ each binned set of point *before loading into memory*. This can significantly reduce
1627
+ memory consumption with large numbers of points at the cost of increased runtime.
1628
+ By default False.
1629
+ indices: tuple | None, optional
1630
+ Experimental. See :func:`interpolation.interp`. None by default.
1631
+ return_indices: bool, optional
1632
+ Experimental. See :func:`interpolation.interp`. False by default.
1633
+ Note that values returned differ when ``lowmem=True`` and ``lowmem=False``,
1634
+ so output should only be re-used in calls with the same ``lowmem`` value.
1635
+
1636
+ Returns
1637
+ -------
1638
+ np.ndarray
1639
+ Interpolated values
1640
+
1641
+ See Also
1642
+ --------
1643
+ :meth:`GeoVectorDataset.intersect_met`
1644
+
1645
+ Examples
1646
+ --------
1647
+ >>> from datetime import datetime
1648
+ >>> import numpy as np
1649
+ >>> import pandas as pd
1650
+ >>> from pycontrails.datalib.ecmwf import ERA5
1651
+
1652
+ >>> times = (datetime(2022, 3, 1, 12), datetime(2022, 3, 1, 15))
1653
+ >>> variables = "air_temperature"
1654
+ >>> levels = [200, 250, 300]
1655
+ >>> era5 = ERA5(times, variables, levels)
1656
+ >>> met = era5.open_metdataset()
1657
+ >>> mda = met["air_temperature"]
1658
+
1659
+ >>> # Interpolation at a grid point agrees with value
1660
+ >>> mda.interpolate(1, 2, 300, np.datetime64('2022-03-01T14:00'))
1661
+ array([241.91972984])
1662
+
1663
+ >>> da = mda.data
1664
+ >>> da.sel(longitude=1, latitude=2, level=300, time=np.datetime64('2022-03-01T14')).item()
1665
+ 241.9197298421629
1666
+
1667
+ >>> # Interpolation off grid
1668
+ >>> mda.interpolate(1.1, 2.1, 290, np.datetime64('2022-03-01 13:10'))
1669
+ array([239.83793798])
1670
+
1671
+ >>> # Interpolate along path
1672
+ >>> longitude = np.linspace(1, 2, 10)
1673
+ >>> latitude = np.linspace(2, 3, 10)
1674
+ >>> level = np.linspace(200, 300, 10)
1675
+ >>> time = pd.date_range("2022-03-01T14", periods=10, freq="5min")
1676
+ >>> mda.interpolate(longitude, latitude, level, time)
1677
+ array([220.44347694, 223.08900738, 225.74338924, 228.41642088,
1678
+ 231.10858599, 233.54857391, 235.71504913, 237.86478872,
1679
+ 239.99274623, 242.10792167])
1680
+
1681
+ >>> # Can easily switch to alternative low-memory implementation
1682
+ >>> mda.interpolate(longitude, latitude, level, time, lowmem=True)
1683
+ array([220.44347694, 223.08900738, 225.74338924, 228.41642088,
1684
+ 231.10858599, 233.54857391, 235.71504913, 237.86478872,
1685
+ 239.99274623, 242.10792167])
1686
+ """
1687
+ if lowmem:
1688
+ return self._interp_lowmem(
1689
+ longitude,
1690
+ latitude,
1691
+ level,
1692
+ time,
1693
+ method=method,
1694
+ bounds_error=bounds_error,
1695
+ fill_value=fill_value,
1696
+ indices=indices,
1697
+ return_indices=return_indices,
1698
+ )
1699
+
1700
+ # Load if necessary
1701
+ if not self.in_memory:
1702
+ self._check_memory("Interpolation over")
1703
+ self.data.load()
1704
+
1705
+ # Convert all inputs to 1d arrays
1706
+ # Not validating against ndim >= 2
1707
+ longitude, latitude, level, time = np.atleast_1d(longitude, latitude, level, time)
1708
+
1709
+ # Pass off to the interp function, which does all the heavy lifting
1710
+ return interpolation.interp(
1711
+ longitude=longitude,
1712
+ latitude=latitude,
1713
+ level=level,
1714
+ time=time,
1715
+ da=self.data,
1716
+ method=method,
1717
+ bounds_error=bounds_error,
1718
+ fill_value=fill_value,
1719
+ localize=localize,
1720
+ indices=indices,
1721
+ return_indices=return_indices,
1722
+ )
1723
+
1724
+ def _interp_lowmem(
1725
+ self,
1726
+ longitude: float | npt.NDArray[np.floating],
1727
+ latitude: float | npt.NDArray[np.floating],
1728
+ level: float | npt.NDArray[np.floating],
1729
+ time: np.datetime64 | npt.NDArray[np.datetime64],
1730
+ *,
1731
+ method: str = "linear",
1732
+ bounds_error: bool = False,
1733
+ fill_value: float | np.float64 | None = np.nan,
1734
+ indices: interpolation.RGIArtifacts | None = None,
1735
+ return_indices: bool = False,
1736
+ ) -> npt.NDArray[np.floating] | tuple[npt.NDArray[np.floating], interpolation.RGIArtifacts]:
1737
+ """Interpolate values against underlying DataArray.
1738
+
1739
+ This method is used by :meth:`interpolate` when ``lowmem=True``.
1740
+ Parameters and return types are identical to :meth:`interpolate`, except
1741
+ that the ``localize`` keyword argument is omitted.
1742
+ """
1743
+ # Convert all inputs to 1d arrays
1744
+ # Not validating against ndim >= 2
1745
+ longitude, latitude, level, time = np.atleast_1d(longitude, latitude, level, time)
1746
+
1747
+ if bounds_error:
1748
+ _lowmem_boundscheck(time, self.data)
1749
+
1750
+ # Create buffers for holding interpolation output
1751
+ # Use np.full rather than np.empty so points not covered
1752
+ # by masks are filled with correct out-of-bounds values.
1753
+ out = np.full(longitude.shape, fill_value, dtype=self.data.dtype)
1754
+ if return_indices:
1755
+ rgi_artifacts = interpolation.RGIArtifacts(
1756
+ xi_indices=np.full((4, longitude.size), -1, dtype=np.int64),
1757
+ norm_distances=np.full((4, longitude.size), np.nan, dtype=np.float64),
1758
+ out_of_bounds=np.full((longitude.size,), True, dtype=np.bool_),
1759
+ )
1760
+
1761
+ # Iterate over portions of points between adjacent time steps in gridded data
1762
+ for mask in _lowmem_masks(time, self.data["time"].values):
1763
+ if mask is None or not np.any(mask):
1764
+ continue
1765
+
1766
+ lon_sl = longitude[mask]
1767
+ lat_sl = latitude[mask]
1768
+ lev_sl = level[mask]
1769
+ t_sl = time[mask]
1770
+ if indices is not None:
1771
+ indices_sl = interpolation.RGIArtifacts(
1772
+ xi_indices=indices.xi_indices[:, mask],
1773
+ norm_distances=indices.norm_distances[:, mask],
1774
+ out_of_bounds=indices.out_of_bounds[mask],
1775
+ )
1776
+ else:
1777
+ indices_sl = None
1778
+
1779
+ coords = {"longitude": lon_sl, "latitude": lat_sl, "level": lev_sl, "time": t_sl}
1780
+ if any(np.all(np.isnan(coord)) for coord in coords.values()):
1781
+ continue
1782
+ da = interpolation._localize(self.data, coords)
1783
+ if not da._in_memory:
1784
+ logger.debug(
1785
+ "Loading %s MB subset of %s into memory.",
1786
+ round(da.nbytes / 1_000_000, 2),
1787
+ da.name,
1788
+ )
1789
+ da.load()
1790
+
1791
+ if return_indices:
1792
+ out[mask], rgi_sl = interpolation.interp(
1793
+ longitude=lon_sl,
1794
+ latitude=lat_sl,
1795
+ level=lev_sl,
1796
+ time=t_sl,
1797
+ da=da,
1798
+ method=method,
1799
+ bounds_error=bounds_error,
1800
+ fill_value=fill_value,
1801
+ localize=False, # would be no-op; da is localized already
1802
+ indices=indices_sl,
1803
+ return_indices=return_indices,
1804
+ )
1805
+ rgi_artifacts.xi_indices[:, mask] = rgi_sl.xi_indices
1806
+ rgi_artifacts.norm_distances[:, mask] = rgi_sl.norm_distances
1807
+ rgi_artifacts.out_of_bounds[mask] = rgi_sl.out_of_bounds
1808
+ else:
1809
+ out[mask] = interpolation.interp(
1810
+ longitude=lon_sl,
1811
+ latitude=lat_sl,
1812
+ level=lev_sl,
1813
+ time=t_sl,
1814
+ da=da,
1815
+ method=method,
1816
+ bounds_error=bounds_error,
1817
+ fill_value=fill_value,
1818
+ localize=False, # would be no-op; da is localized already
1819
+ indices=indices_sl,
1820
+ return_indices=return_indices,
1821
+ )
1822
+
1823
+ if return_indices:
1824
+ return out, rgi_artifacts
1825
+ return out
1826
+
1827
+ def _check_memory(self, msg_start: str) -> None:
1828
+ """Check the memory usage of the underlying data.
1829
+
1830
+ If the data is larger than 4 GB, a warning is issued. If the data is
1831
+ larger than 32 GB, a RuntimeError is raised.
1832
+ """
1833
+ n_bytes = self.data.nbytes
1834
+ mb = round(n_bytes / int(1e6), 2)
1835
+ logger.debug("Loading %s into memory consumes %s MB.", self.name, mb)
1836
+
1837
+ n_gb = n_bytes // int(1e9)
1838
+ if n_gb <= 4:
1839
+ return
1840
+
1841
+ # Prevent something stupid
1842
+ msg = (
1843
+ f"{msg_start} MetDataArray {self.name} requires loading "
1844
+ f"at least {n_gb} GB of data into memory. Downselect data if possible. "
1845
+ "If working with a GeoVectorDataset instance, this can be achieved "
1846
+ "with the method 'downselect_met'."
1847
+ )
1848
+
1849
+ if n_gb > 32:
1850
+ raise RuntimeError(msg)
1851
+ warnings.warn(msg)
1852
+
1853
+ def save(self, **kwargs: Any) -> list[str]:
1854
+ """Save intermediate to :attr:`cachestore` as netcdf.
1855
+
1856
+ Load and restore using :meth:`load`.
1857
+
1858
+ Parameters
1859
+ ----------
1860
+ **kwargs : Any
1861
+ Keyword arguments passed directly to :func:`xarray.save_mfdataset`
1862
+
1863
+ Returns
1864
+ -------
1865
+ list[str]
1866
+ Returns filenames of saved files
1867
+ """
1868
+ dataset = self.data.to_dataset()
1869
+ return self._save(dataset, **kwargs)
1870
+
1871
+ @classmethod
1872
+ def load(
1873
+ cls,
1874
+ hash: str,
1875
+ cachestore: CacheStore | None = None,
1876
+ chunks: dict[str, int] | None = None,
1877
+ ) -> Self:
1878
+ """Load saved intermediate from :attr:`cachestore`.
1879
+
1880
+ Parameters
1881
+ ----------
1882
+ hash : str
1883
+ Saved hash to load.
1884
+ cachestore : CacheStore, optional
1885
+ Cache datastore to use for sourcing files.
1886
+ Defaults to DiskCacheStore.
1887
+ chunks : dict[str, int], optional
1888
+ Chunks kwarg passed to :func:`xarray.open_mfdataset()` when opening files.
1889
+
1890
+ Returns
1891
+ -------
1892
+ MetDataArray
1893
+ New MetDataArray with loaded data.
1894
+ """
1895
+ cachestore = cachestore or DiskCacheStore()
1896
+ chunks = chunks or {}
1897
+ data = _load(hash, cachestore, chunks)
1898
+ return cls(data[next(iter(data.data_vars))])
1899
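+     # Illustrative save/load round trip (sketch; ``saved_hash`` is a hypothetical name
+     # for whatever hash identifies the files written by ``save``):
+     #   paths = mda.save()  # writes netCDF files to mda.cachestore
+     #   mda2 = MetDataArray.load(saved_hash, cachestore=mda.cachestore)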
+
1900
+ @property
1901
+ def proportion(self) -> float:
1902
+ """Compute proportion of points with value 1.
1903
+
1904
+ Returns
1905
+ -------
1906
+ float
1907
+ Proportion of points with value 1
1908
+
1909
+ Raises
1910
+ ------
1911
+ NotImplementedError
1912
+ If instance does not contain binary data.
1913
+ """
1914
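+     # Worked sketch: for a binary field with values [[1, 1], [1, 0]] and no missing
+     # data, the proportion is 3 / 4 = 0.75 (count of ones divided by total point count).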
+ if not self.binary:
1915
+ raise NotImplementedError("proportion method is only implemented for binary fields")
1916
+
1917
+ return self.data.sum().values.item() / self.data.count().values.item()
1918
+
1919
+ def find_edges(self) -> Self:
1920
+ """Find edges of regions.
1921
+
1922
+ Returns
1923
+ -------
1924
+ Self
1925
+ MetDataArray with a binary field, 1 on the edge of the regions,
1926
+ 0 outside and inside the regions.
1927
+
1928
+ Raises
1929
+ ------
1930
+ NotImplementedError
1931
+ If the instance is not binary.
1932
+ """
1933
+ if not self.binary:
1934
+ raise NotImplementedError("find_edges method is only implemented for binary fields")
1935
+
1936
+ # edge detection algorithm using differentiation to reduce the areas to lines
1937
+ def _edges(da: xr.DataArray) -> xr.DataArray:
1938
+ lat_diff = da.differentiate("latitude")
1939
+ lon_diff = da.differentiate("longitude")
1940
+ diff = da.where((lat_diff != 0) | (lon_diff != 0), 0)
1941
+
1942
+ # TODO: what is desired behavior here?
1943
+ # set boundaries to close contour regions
1944
+ diff[dict(longitude=0)] = da[dict(longitude=0)]
1945
+ diff[dict(longitude=-1)] = da[dict(longitude=-1)]
1946
+ diff[dict(latitude=0)] = da[dict(latitude=0)]
1947
+ diff[dict(latitude=-1)] = da[dict(latitude=-1)]
1948
+
1949
+ return diff
1950
+
1951
+ # load data into memory (required for value assignment in _edges())
1952
+ self.data.load()
1953
+
1954
+ data = self.data.groupby("level", squeeze=False).map(_edges)
1955
+ return type(self)(data, cachestore=self.cachestore)
1956
+
1957
+ def to_polygon_feature(
1958
+ self,
1959
+ level: float | int | None = None,
1960
+ time: np.datetime64 | datetime | None = None,
1961
+ fill_value: float = np.nan,
1962
+ iso_value: float | None = None,
1963
+ min_area: float = 0.0,
1964
+ epsilon: float = 0.0,
1965
+ lower_bound: bool = True,
1966
+ precision: int | None = None,
1967
+ interiors: bool = True,
1968
+ convex_hull: bool = False,
1969
+ include_altitude: bool = False,
1970
+ properties: dict[str, Any] | None = None,
1971
+ ) -> dict[str, Any]:
1972
+ """Create GeoJSON Feature artifact from spatial array on a single level and time slice.
1973
+
1974
+ Computed polygons always contain an exterior linear ring as defined by the
1975
+ `GeoJSON Polygon specification <https://www.rfc-editor.org/rfc/rfc7946.html#section-3.1.6>`_.
1976
+ Polygons may also contain interior linear rings (holes). This method does not support
1977
+ nesting beyond the GeoJSON specification. See the :mod:`pycontrails.core.polygon`
1978
+ module for additional polygon support.
1979
+
1980
+ .. versionchanged:: 0.25.12
1981
+
1982
+ The previous implementation included several additional parameters which have
1983
+ been removed:
1984
+
1985
+ - The ``approximate`` parameter
1986
+ - A ``path`` parameter to save output as JSON
1987
+ - Passing arbitrary kwargs to :func:`skimage.measure.find_contours`.
1988
+
1989
+ The new implementation includes parameters that were previously missing:
1990
+
1991
+ - ``fill_value``
1992
+ - ``min_area``
1993
+ - ``include_altitude``
1994
+
1995
+ .. versionchanged:: 0.38.0
1996
+
1997
+ Change default value of ``epsilon`` from 0.15 to 0.
1998
+
1999
+ .. versionchanged:: 0.41.0
2000
+
2001
+ Convert continuous fields to binary fields before computing polygons.
2002
+ The parameters ``min_area`` and ``epsilon`` are now expressed in terms of
2003
+ longitude/latitude units instead of pixels.
2004
+
2005
+ Parameters
2006
+ ----------
2007
+ level : float, optional
2008
+ Level slice to create polygons.
2009
+ If the "level" coordinate is length 1, then the single level slice will be selected
2010
+ automatically.
2011
+ time : datetime, optional
2012
+ Time slice to create polygons.
2013
+ If the "time" coordinate is length 1, then the single time slice will be selected
2014
+ automatically.
2015
+ fill_value : float, optional
2016
+ Value used for filling missing data and for padding the underlying data array.
2017
+ Set to ``np.nan`` by default, which ensures that regions with missing data are
2018
+ never included in polygons.
2019
+ iso_value : float, optional
2020
+ Value in field to create iso-surface.
2021
+ Defaults to the average of the min and max value of the array. (This is the
2022
+ same convention as used by ``skimage``.)
2023
+ min_area : float, optional
2024
+ Minimum area of each polygon. Polygons with area less than ``min_area`` are
2025
+ not included in the output. The unit of this parameter is in longitude/latitude
2026
+ degrees squared. Set to 0 to skip any filtering of polygons based on a minimum
2027
+ area condition. By default, 0.0.
2028
+ epsilon : float, optional
2029
+ Control the extent to which the polygon is simplified. A value of 0 does not alter
2030
+ the geometry of the polygon. The unit of this parameter is in longitude/latitude
2031
+ degrees. By default, 0.0.
2032
+ lower_bound : bool, optional
2033
+ Whether to use ``iso_value`` as a lower or upper bound on values in polygon interiors.
2034
+ By default, True.
2035
+ precision : int, optional
2036
+ Number of decimal places to round coordinates to. If None, no rounding is performed.
2037
+ interiors : bool, optional
2038
+ If True, include interior linear rings (holes) in the output. True by default.
2039
+ convex_hull : bool, optional
2040
+ EXPERIMENTAL. If True, compute the convex hull of each polygon. Only implemented
2041
+ for depth=1. False by default. A warning is issued if the underlying algorithm
2042
+ fails to make valid polygons after computing the convex hull.
2043
+ include_altitude : bool, optional
2044
+ If True, include the array altitude [:math:`m`] as a z-coordinate in the
2045
+ `GeoJSON output <https://www.rfc-editor.org/rfc/rfc7946#section-3.1.1>`_.
2046
+ False by default.
2047
+ properties : dict, optional
2048
+ Additional properties to include in the GeoJSON output. By default, None.
2049
+
2050
+ Returns
2051
+ -------
2052
+ dict[str, Any]
2053
+ Python representation of GeoJSON Feature with MultiPolygon geometry.
2054
+
2055
+ Notes
2056
+ -----
2057
+ :class:`Cocip` and :class:`CocipGrid` set some quantities to 0 and other quantities
2058
+ to ``np.nan`` in regions where no contrails form. When computing polygons from
2059
+ :class:`Cocip` or :class:`CocipGrid` output, take care that the choice of
2060
+ ``fill_value`` correctly includes or excludes contrail-free regions. See the
2061
+ :class:`Cocip` documentation for details about ``np.nan`` in model output.
2062
+
2063
+ See Also
2064
+ --------
2065
+ :meth:`to_polyhedra`
2066
+ :func:`polygons.find_multipolygons`
2067
+
2068
+ Examples
2069
+ --------
2070
+ >>> from pprint import pprint
2071
+ >>> from pycontrails.datalib.ecmwf import ERA5
2072
+ >>> era5 = ERA5("2022-03-01", variables="air_temperature", pressure_levels=250)
2073
+ >>> mda = era5.open_metdataset()["air_temperature"]
2074
+ >>> mda.shape
2075
+ (1440, 721, 1, 1)
2076
+
2077
+ >>> pprint(mda.to_polygon_feature(iso_value=239.5, precision=2, epsilon=0.1))
2078
+ {'geometry': {'coordinates': [[[[167.88, -22.5],
2079
+ [167.75, -22.38],
2080
+ [167.62, -22.5],
2081
+ [167.75, -22.62],
2082
+ [167.88, -22.5]]],
2083
+ [[[43.38, -33.5],
2084
+ [43.5, -34.12],
2085
+ [43.62, -33.5],
2086
+ [43.5, -33.38],
2087
+ [43.38, -33.5]]]],
2088
+ 'type': 'MultiPolygon'},
2089
+ 'properties': {},
2090
+ 'type': 'Feature'}
2091
+
2092
+ """
2093
+ if convex_hull and interiors:
2094
+ raise ValueError("Set 'interiors=False' to use the 'convex_hull' parameter.")
2095
+
2096
+ arr, altitude = _extract_2d_arr_and_altitude(self, level, time)
2097
+ if not include_altitude:
2098
+ altitude = None # this logic used below
2099
+ elif altitude is None:
2100
+ raise ValueError(
2101
+ "The parameter 'include_altitude' is True, but altitude is not "
2102
+ "found on MetDataArray instance. Either set altitude, or pass "
2103
+ "include_altitude=False."
2104
+ )
2105
+
2106
+ if not np.isnan(fill_value):
2107
+ np.nan_to_num(arr, copy=False, nan=fill_value)
2108
+
2109
+ # default iso_value
2110
+ if iso_value is None:
2111
+ iso_value = (np.nanmax(arr) + np.nanmin(arr)) / 2
2112
+ warnings.warn(f"The 'iso_value' parameter was not specified. Using value: {iso_value}")
2113
+
2114
+ # We'll get a nice error message if dependencies are not installed
2115
+ from pycontrails.core import polygon
2116
+
2117
+ # Convert to nested lists of coordinates for GeoJSON representation
2118
+ indexes = self.indexes
2119
+ longitude = indexes["longitude"].to_numpy()
2120
+ latitude = indexes["latitude"].to_numpy()
2121
+
2122
+ mp = polygon.find_multipolygon(
2123
+ arr,
2124
+ threshold=iso_value,
2125
+ min_area=min_area,
2126
+ epsilon=epsilon,
2127
+ lower_bound=lower_bound,
2128
+ interiors=interiors,
2129
+ convex_hull=convex_hull,
2130
+ longitude=longitude,
2131
+ latitude=latitude,
2132
+ precision=precision,
2133
+ )
2134
+
2135
+ return polygon.multipolygon_to_geojson(mp, altitude, properties)
2136
+
2137
+ def to_polygon_feature_collection(
2138
+ self,
2139
+ time: np.datetime64 | datetime | None = None,
2140
+ fill_value: float = np.nan,
2141
+ iso_value: float | None = None,
2142
+ min_area: float = 0.0,
2143
+ epsilon: float = 0.0,
2144
+ lower_bound: bool = True,
2145
+ precision: int | None = None,
2146
+ interiors: bool = True,
2147
+ convex_hull: bool = False,
2148
+ include_altitude: bool = False,
2149
+ properties: dict[str, Any] | None = None,
2150
+ ) -> dict[str, Any]:
2151
+ """Create GeoJSON FeatureCollection artifact from spatial array at time slice.
2152
+
2153
+ See the :meth:`to_polygon_feature` method for a description of the parameters.
2154
+
2155
+ Returns
2156
+ -------
2157
+ dict[str, Any]
2158
+ Python representation of GeoJSON FeatureCollection. This dictionary is
2159
+ composed of individual GeoJSON Features, one per value of :attr:`self.data["level"]`.
2160
+ """
2161
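+         # Illustrative usage (sketch, not from the test suite): for an `mda` spanning
+         # several pressure levels, something like
+         #   fc = mda.to_polygon_feature_collection(iso_value=235.0)
+         # returns one Feature per level, so len(fc["features"]) == len(mda.data["level"]).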
+ base_properties = properties or {}
2162
+ features = []
2163
+ for level in self.data["level"]:
2164
+ properties = base_properties.copy()
2165
+ properties.update(level=level.item())
2166
+ properties.update({f"level_{k}": v for k, v in self.data["level"].attrs.items()})
2167
+
2168
+ feature = self.to_polygon_feature(
2169
+ level=level,
2170
+ time=time,
2171
+ fill_value=fill_value,
2172
+ iso_value=iso_value,
2173
+ min_area=min_area,
2174
+ epsilon=epsilon,
2175
+ lower_bound=lower_bound,
2176
+ precision=precision,
2177
+ interiors=interiors,
2178
+ convex_hull=convex_hull,
2179
+ include_altitude=include_altitude,
2180
+ properties=properties,
2181
+ )
2182
+ features.append(feature)
2183
+
2184
+ return {
2185
+ "type": "FeatureCollection",
2186
+ "features": features,
2187
+ }
2188
+
2189
+ @overload
2190
+ def to_polyhedra(
2191
+ self,
2192
+ *,
2193
+ time: datetime | None = ...,
2194
+ iso_value: float = ...,
2195
+ simplify_fraction: float = ...,
2196
+ lower_bound: bool = ...,
2197
+ return_type: Literal["geojson"],
2198
+ path: str | None = ...,
2199
+ altitude_scale: float = ...,
2200
+ output_vertex_normals: bool = ...,
2201
+ closed: bool = ...,
2202
+ ) -> dict: ...
2203
+
2204
+ @overload
2205
+ def to_polyhedra(
2206
+ self,
2207
+ *,
2208
+ time: datetime | None = ...,
2209
+ iso_value: float = ...,
2210
+ simplify_fraction: float = ...,
2211
+ lower_bound: bool = ...,
2212
+ return_type: Literal["mesh"],
2213
+ path: str | None = ...,
2214
+ altitude_scale: float = ...,
2215
+ output_vertex_normals: bool = ...,
2216
+ closed: bool = ...,
2217
+ ) -> o3d.geometry.TriangleMesh: ...
2218
+
2219
+ def to_polyhedra(
2220
+ self,
2221
+ *,
2222
+ time: datetime | None = None,
2223
+ iso_value: float = 0.0,
2224
+ simplify_fraction: float = 1.0,
2225
+ lower_bound: bool = True,
2226
+ return_type: str = "geojson",
2227
+ path: str | None = None,
2228
+ altitude_scale: float = 1.0,
2229
+ output_vertex_normals: bool = False,
2230
+ closed: bool = True,
2231
+ ) -> dict | o3d.geometry.TriangleMesh:
2232
+ """Create a collection of polyhedra from spatial array corresponding to a single time slice.
2233
+
2234
+ Parameters
2235
+ ----------
2236
+ time : datetime, optional
2237
+ Time slice to create mesh.
2238
+ If the "time" coordinate is length 1, then the single time slice will be selected
2239
+ automatically.
2240
+ iso_value : float, optional
2241
+ Value in field to create iso-surface. Defaults to 0.
2242
+ simplify_fraction : float, optional
2243
+ Apply `open3d` `simplify_quadric_decimation` method to simplify the polyhedra geometry.
2244
+ This parameter must be in the half-open interval (0.0, 1.0].
2245
+ Defaults to 1.0, corresponding to no reduction.
2246
+ lower_bound : bool, optional
2247
+ Whether to use ``iso_value`` as a lower or upper bound on values in polyhedra interiors.
2248
+ By default, True.
2249
+ return_type : str, optional
2250
+ Must be one of "geojson" or "mesh". Defaults to "geojson".
2251
+ If "geojson", this method returns a dictionary representation of a geojson MultiPolygon
2252
+ object whose polygons are polyhedra faces.
2253
+ If "mesh", this method returns an `open3d` `TriangleMesh` instance.
2254
+ path : str, optional
2255
+ Output geojson or mesh to file.
2256
+ If `return_type` is "mesh", see `Open3D File I/O for Mesh
2257
+ <http://www.open3d.org/docs/release/tutorial/geometry/file_io.html#Mesh>`_ for
2258
+ file type options.
2259
+ altitude_scale : float, optional
2260
+ Rescale the altitude dimension of the mesh, [:math:`m`]
2261
+ output_vertex_normals : bool, optional
2262
+ If ``path`` is defined, write out vertex normals.
2263
+ Defaults to False.
2264
+ closed : bool, optional
2265
+ If True, pad spatial array along all axes to ensure polyhedra are "closed".
2266
+ This flag often gives rise to cleaner visualizations. Defaults to True.
2267
+
2268
+ Returns
2269
+ -------
2270
+ dict | open3d.geometry.TriangleMesh
2271
+ Python representation of geojson object or `Open3D Triangle Mesh
2272
+ <http://www.open3d.org/docs/release/tutorial/geometry/mesh.html>`_ depending on the
2273
+ `return_type` parameter.
2274
+
2275
+ Raises
2276
+ ------
2277
+ ModuleNotFoundError
2278
+ Method requires the `vis` optional dependencies
2279
+ ValueError
2280
+ If input parameters are invalid.
2281
+
2282
+ See Also
2283
+ --------
2284
+ :meth:`to_polygon_feature`
2285
+ :func:`skimage.measure.marching_cubes`
2286
+ :class:`open3d.geometry.TriangleMesh`
2287
+
2288
+ Notes
2289
+ -----
2290
+ Uses the `scikit-image Marching Cubes <https://scikit-image.org/docs/dev/auto_examples/edges/plot_marching_cubes.html>`_
2291
+ algorithm to reconstruct a surface from the point-cloud like arrays.
2292
+ """
2293
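+         # Illustrative usage (sketch; assumes the scikit-image and open3d optional
+         # dependencies are installed and `mda` has multiple levels and a single time):
+         #   mesh = mda.to_polyhedra(iso_value=0.5, return_type="mesh")
+         #   geojson = mda.to_polyhedra(iso_value=0.5, return_type="geojson", path="out.json")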
+ try:
2294
+ from skimage import measure
2295
+ except ModuleNotFoundError as e:
2296
+ dependencies.raise_module_not_found_error(
2297
+ name="MetDataArray.to_polyhedra method",
2298
+ package_name="scikit-image",
2299
+ pycontrails_optional_package="vis",
2300
+ module_not_found_error=e,
2301
+ )
2302
+
2303
+ try:
2304
+ import open3d as o3d
2305
+ except ModuleNotFoundError as e:
2306
+ dependencies.raise_module_not_found_error(
2307
+ name="MetDataArray.to_polyhedra method",
2308
+ package_name="open3d",
2309
+ pycontrails_optional_package="open3d",
2310
+ module_not_found_error=e,
2311
+ )
2312
+
2313
+ if len(self.data["level"]) == 1:
2314
+ raise ValueError(
2315
+ "Found single `level` coordinate in DataArray. This method requires at least two."
2316
+ )
2317
+
2318
+ # select time
2319
+ if time is None and len(self.data["time"]) == 1:
2320
+ time = self.data["time"].values[0]
2321
+
2322
+ if time is None:
2323
+ raise ValueError(
2324
+ "time input must be defined when the length of the time coordinates are > 1"
2325
+ )
2326
+
2327
+ if simplify_fraction > 1 or simplify_fraction <= 0:
2328
+ raise ValueError("Parameter `simplify_fraction` must be in the interval (0, 1].")
2329
+
2330
+ return_types = ["geojson", "mesh"]
2331
+ if return_type not in return_types:
2332
+ raise ValueError(f"Parameter `return_type` must be one of {', '.join(return_types)}")
2333
+
2334
+ # 3d array of longitude, latitude, altitude values
2335
+ volume = self.data.sel(time=time).values
2336
+
2337
+ # invert if iso_value is an upper bound on interior values
2338
+ if not lower_bound:
2339
+ volume = -volume
2340
+ iso_value = -iso_value
2341
+
2342
+ # convert from array index back to coordinates
2343
+ longitude = self.indexes["longitude"].values
2344
+ latitude = self.indexes["latitude"].values
2345
+ altitude = units.pl_to_m(self.indexes["level"].values)
2346
+
2347
+ # Pad volume on all axes to close the volumes
2348
+ if closed:
2349
+ # pad values to domain
2350
+ longitude0 = longitude[0] - (longitude[1] - longitude[0])
2351
+ longitude1 = longitude[-1] + longitude[-1] - longitude[-2]
2352
+ longitude = np.pad(longitude, pad_width=1, constant_values=(longitude0, longitude1))
2353
+
2354
+ latitude0 = latitude[0] - (latitude[1] - latitude[0])
2355
+ latitude1 = latitude[-1] + latitude[-1] - latitude[-2]
2356
+ latitude = np.pad(latitude, pad_width=1, constant_values=(latitude0, latitude1))
2357
+
2358
+ altitude0 = altitude[0] - (altitude[1] - altitude[0])
2359
+ altitude1 = altitude[-1] + altitude[-1] - altitude[-2]
2360
+ altitude = np.pad(altitude, pad_width=1, constant_values=(altitude0, altitude1))
2361
+
2362
+ # Pad along axes to ensure polygons are closed
2363
+ volume = np.pad(volume, pad_width=1, constant_values=iso_value)
2364
+
2365
+ # Use marching cubes to obtain the surface mesh
2366
+ # Coordinates of verts are indexes of volume array
2367
+ verts, faces, normals, _ = measure.marching_cubes(
2368
+ volume, iso_value, allow_degenerate=False, gradient_direction="ascent"
2369
+ )
2370
+
2371
+ # Convert from indexes to longitude, latitude, altitude values
2372
+ verts[:, 0] = longitude[verts[:, 0].astype(int)]
2373
+ verts[:, 1] = latitude[verts[:, 1].astype(int)]
2374
+ verts[:, 2] = altitude[verts[:, 2].astype(int)]
2375
+
2376
+ # rescale altitude
2377
+ verts[:, 2] = verts[:, 2] * altitude_scale
2378
+
2379
+ # create mesh in open3d
2380
+ mesh = o3d.geometry.TriangleMesh()
2381
+ mesh.vertices = o3d.utility.Vector3dVector(verts)
2382
+ mesh.triangles = o3d.utility.Vector3iVector(faces)
2383
+ mesh.vertex_normals = o3d.utility.Vector3dVector(normals)
2384
+
2385
+ # simplify mesh according to simplify_fraction
2386
+ if simplify_fraction < 1:
2387
+ target_n_triangles = int(faces.shape[0] * simplify_fraction)
2388
+ mesh = mesh.simplify_quadric_decimation(target_number_of_triangles=target_n_triangles)
2389
+ mesh.compute_vertex_normals()
2390
+
2391
+ if path is not None:
2392
+ path = str(pathlib.Path(path).absolute())
2393
+
2394
+ if return_type == "geojson":
2395
+ verts = np.round(
2396
+ np.asarray(mesh.vertices), decimals=4
2397
+ ) # rounding to reduce the size of resultant json arrays
2398
+ faces = np.asarray(mesh.triangles)
2399
+
2400
+ # TODO: technically this is not valid GeoJSON because each polygon (triangle)
2401
+ # does not have the last element equal to the first (not a linear ring)
2402
+ # but it still works for now in Deck.GL
2403
+ coords = [[verts[face].tolist()] for face in faces]
2404
+
2405
+ geojson = {
2406
+ "type": "Feature",
2407
+ "properties": {},
2408
+ "geometry": {"type": "MultiPolygon", "coordinates": coords},
2409
+ }
2410
+
2411
+ if path is not None:
2412
+ if not path.endswith(".json"):
2413
+ path += ".json"
2414
+ with open(path, "w") as file:
2415
+ json.dump(geojson, file)
2416
+ return geojson
2417
+
2418
+ if path is not None:
2419
+ o3d.io.write_triangle_mesh(
2420
+ path,
2421
+ mesh,
2422
+ write_ascii=False,
2423
+ compressed=True,
2424
+ write_vertex_normals=output_vertex_normals,
2425
+ write_vertex_colors=False,
2426
+ write_triangle_uvs=True,
2427
+ print_progress=False,
2428
+ )
2429
+ return mesh
2430
+
2431
+ @override
2432
+ def broadcast_coords(self, name: str) -> xr.DataArray:
2433
+ da = xr.ones_like(self.data) * self.data[name]
2434
+ da.name = name
2435
+
2436
+ return da
2437
+
2438
+
2439
+ def _is_wrapped(longitude: np.ndarray) -> bool:
2440
+ """Check if ``longitude`` covers ``[-180, 180]``."""
2441
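+     # e.g. np.arange(-180.0, 180.25, 0.25) is wrapped (first value -180, last value 180),
+     # while np.arange(-180.0, 180.0, 0.25) is not (it stops at 179.75)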
+ return longitude[0] <= -180.0 and longitude[-1] >= 180.0
2442
+
2443
+
2444
+ def _is_zarr(ds: xr.Dataset | xr.DataArray) -> bool:
2445
+ """Check if ``ds`` appears to be Zarr-based.
2446
+
2447
+ Neither ``xarray`` nor ``dask`` readily expose such information, so
2448
+ implementation is very narrow and brittle.
2449
+ """
2450
+ if isinstance(ds, xr.Dataset):
2451
+ # Attempt 1: xarray binds the zarr close function to the instance
2452
+ # This gives us an indication that the data is Zarr-based
2453
+ try:
2454
+ if ds._close.__func__.__qualname__ == "ZarrStore.close": # type: ignore
2455
+ return True
2456
+ except AttributeError:
2457
+ pass
2458
+
2459
+ # Grab the first data variable, get underlying DataArray
2460
+ ds = ds[next(iter(ds.data_vars))]
2461
+
2462
+ # Attempt 2: Examine the dask graph
2463
+ darr = ds.variable._data # dask array in some cases
2464
+ try:
2465
+ # Get the first dask instruction
2466
+ dask0 = darr.dask[next(iter(darr.dask))] # type: ignore[union-attr]
2467
+ except AttributeError:
2468
+ return False
2469
+ return dask0.array.array.array.__class__.__name__ == "ZarrArrayWrapper"
2470
+
2471
+
2472
+ def shift_longitude(data: XArrayType, bound: float = -180.0) -> XArrayType:
2473
+ """Shift longitude values from any input domain to [bound, 360 + bound) domain.
2474
+
2475
+ Sorts data by ascending longitude values.
2476
+
2477
+
2478
+ Parameters
2479
+ ----------
2480
+ data : XArrayType
2481
+ :class:`xr.Dataset` or :class:`xr.DataArray` with longitude dimension
2482
+ bound : float, optional
2483
+ Lower bound of the domain.
2484
+ Output domain will be [bound, 360 + bound).
2485
+ Defaults to -180, which results in longitude domain [-180, 180).
2486
+
2487
+
2488
+ Returns
2489
+ -------
2490
+ XArrayType
2491
+ :class:`xr.Dataset` or :class:`xr.DataArray` with longitude values on [a, 360 + a).
2492
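+     # Worked sketch: with bound=-180.0, longitudes [0, 90, 180, 270] map to
+     # [0, 90, -180, -90] and are returned sorted as [-180, -90, 0, 90].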
+ """
2493
+ return data.assign_coords(
2494
+ longitude=((data["longitude"].values - bound) % 360.0) + bound
2495
+ ).sortby("longitude", ascending=True)
2496
+
2497
+
2498
+ def _wrap_longitude(data: XArrayType) -> XArrayType:
2499
+ """Wrap longitude grid coordinates.
2500
+
2501
+ This function assumes the longitude dimension on ``data``:
2502
+
2503
+ - is sorted
2504
+ - is contained in [-180, 180]
2505
+
2506
+ These assumptions are checked by :class:`MetDataset` and :class`MetDataArray`
2507
+ constructors.
2508
+
2509
+ .. versionchanged:: 0.26.0
2510
+
2511
+ This function now ensures every value in the interval ``[-180, 180]``
2512
+ is covered by the longitude dimension of the returned object. See
2513
+ :meth:`MetDataset.is_wrapped` for more details.
2514
+
2515
+
2516
+ Parameters
2517
+ ----------
2518
+ data : XArrayType
2519
+ :class:`xr.Dataset` or :class:`xr.DataArray` with longitude dimension
2520
+
2521
+ Returns
2522
+ -------
2523
+ XArrayType
2524
+ Copy of :class:`xr.Dataset` or :class:`xr.DataArray` with wrapped longitude values.
2525
+
2526
+ Raises
2527
+ ------
2528
+ ValueError
2529
+ If longitude values are already wrapped.
2530
+ """
2531
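+     # Worked sketch: a grid with longitude [-180, ..., 179.75] gains a copy of the
+     # -180 column relabelled as 180, so the wrapped result covers [-180, 180].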
+ lon = data._indexes["longitude"].index.to_numpy() # type: ignore[attr-defined]
2532
+ if _is_wrapped(lon):
2533
+ raise ValueError("Longitude values are already wrapped")
2534
+
2535
+ lon0 = lon[0]
2536
+ lon1 = lon[-1]
2537
+
2538
+ # Try to prevent something stupid
2539
+ if lon1 - lon0 < 330.0:
2540
+ warnings.warn("Wrapping longitude will create a large spatial gap of more than 30 degrees.")
2541
+
2542
+ objs = [data]
2543
+ if lon0 > -180.0: # if the lowest longitude is not low enough, duplicate highest
2544
+ dup1 = data.sel(longitude=[lon1]).assign_coords(longitude=[lon1 - 360.0])
2545
+ objs.insert(0, dup1)
2546
+ if lon1 < 180.0: # if the highest longitude is not high enough, duplicate lowest
2547
+ dup0 = data.sel(longitude=[lon0]).assign_coords(longitude=[lon0 + 360.0])
2548
+ objs.append(dup0)
2549
+
2550
+ # Because we explicitly raise a ValueError if longitude already wrapped,
2551
+ # we know that len(objs) > 1, so the concatenation here is nontrivial
2552
+ wrapped = xr.concat(objs, dim="longitude")
2553
+ wrapped["longitude"] = wrapped["longitude"].astype(lon.dtype, copy=False)
2554
+
2555
+ # If there is only one longitude chunk in parameter data, increment
2556
+ # data.chunks can be None, using getattr for extra safety
2557
+ # NOTE: This probably doesn't play well with Zarr data ...
2558
+ # we don't want to be rechunking them.
2559
+ # Ideally we'd raise if data was Zarr-based
2560
+ chunks = getattr(data, "chunks", None) or {}
2561
+ chunks = dict(chunks) # chunks can be frozen
2562
+ lon_chunks = chunks.get("longitude", ())
2563
+ if len(lon_chunks) == 1:
2564
+ chunks["longitude"] = (lon_chunks[0] + len(objs) - 1,)
2565
+ wrapped = wrapped.chunk(chunks)
2566
+
2567
+ return wrapped
2568
+
2569
+
2570
+ def _extract_2d_arr_and_altitude(
2571
+ mda: MetDataArray,
2572
+ level: float | int | None,
2573
+ time: np.datetime64 | datetime | None,
2574
+ ) -> tuple[np.ndarray, float | None]:
2575
+ """Extract underlying 2D array indexed by longitude and latitude.
2576
+
2577
+ Parameters
2578
+ ----------
2579
+ mda : MetDataArray
2580
+ MetDataArray to extract from
2581
+ level : float | int | None
2582
+ Pressure level to slice at
2583
+ time : np.datetime64 | datetime
2584
+ Time to slice at
2585
+
2586
+ Returns
2587
+ -------
2588
+ arr : np.ndarray
2589
+ Copy of 2D array at given level and time
2590
+ altitude : float | None
2591
+ Altitude of slice [:math:`m`]. None if "altitude" not found on ``mda``
2592
+ (i.e., for surface-level :class:`MetDataArray`).
2593
+ """
2594
+ # Determine level if not specified
2595
+ if level is None:
2596
+ level_coord = mda.indexes["level"].values
2597
+ if len(level_coord) == 1:
2598
+ level = level_coord[0]
2599
+ else:
2600
+ raise ValueError(
2601
+ "Parameter 'level' must be defined when the length of the 'level' "
2602
+ "coordinates is not 1."
2603
+ )
2604
+
2605
+ # Determine time if not specified
2606
+ if time is None:
2607
+ time_coord = mda.indexes["time"].values
2608
+ if len(time_coord) == 1:
2609
+ time = time_coord[0]
2610
+ else:
2611
+ raise ValueError(
2612
+ "Parameter 'time' must be defined when the length of the 'time' "
2613
+ "coordinates is not 1."
2614
+ )
2615
+
2616
+ da = mda.data.sel(level=level, time=time)
2617
+ arr = da.values.copy()
2618
+ if arr.ndim != 2:
2619
+ raise RuntimeError("Malformed data array")
2620
+
2621
+ try:
2622
+ altitude = da["altitude"].values.item() # item not implemented on dask arrays
2623
+ except KeyError:
2624
+ altitude = None
2625
+ else:
2626
+ altitude = round(altitude)
2627
+
2628
+ return arr, altitude
2629
+
2630
+
2631
+ def downselect(data: XArrayType, bbox: tuple[float, ...]) -> XArrayType:
2632
+ """Downselect :class:`xr.Dataset` or :class:`xr.DataArray` with spatial bounding box.
2633
+
2634
+ Parameters
2635
+ ----------
2636
+ data : XArrayType
2637
+ xr.Dataset or xr.DataArray to downselect
2638
+ bbox : tuple[float, ...]
2639
+ Tuple of coordinates defining a spatial bounding box in WGS84 coordinates.
2640
+
2641
+ - For 2D queries, ``bbox`` takes the form ``(west, south, east, north)``
2642
+ - For 3D queries, ``bbox`` takes the form
2643
+ ``(west, south, min-level, east, north, max-level)``
2644
+
2645
+ with level defined in [:math:`hPa`].
2646
+
2647
+ Returns
2648
+ -------
2649
+ XArrayType
2650
+ Downselected xr.Dataset or xr.DataArray
2651
+
2652
+ Raises
2653
+ ------
2654
+ ValueError
2655
+ If parameter ``bbox`` has wrong length.
2656
+ """
2657
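+     # Illustrative sketch: bbox=(-10, 40, 20, 60) keeps longitudes in [-10, 20] and
+     # latitudes in [40, 60] via label-based slicing; bbox=(170, -10, -170, 10)
+     # (west > east) instead selects the antimeridian-crossing band via `where` below.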
+ if len(bbox) == 4:
2658
+ west, south, east, north = bbox
2659
+ level_min = -np.inf
2660
+ level_max = np.inf
2661
+
2662
+ elif len(bbox) == 6:
2663
+ west, south, level_min, east, north, level_max = bbox
2664
+
2665
+ else:
2666
+ raise ValueError(
2667
+ f"bbox {bbox} is not length 4 [west, south, east, north] "
2668
+ "or length 6 [west, south, min-level, east, north, max-level]"
2669
+ )
2670
+
2671
+ if west <= east:
2672
+ # Return a view of the data
2673
+ # If data is lazy, this will not load the data
2674
+ return data.sel(
2675
+ longitude=slice(west, east),
2676
+ latitude=slice(south, north),
2677
+ level=slice(level_min, level_max),
2678
+ )
2679
+
2680
+ # In this case, the bbox spans the antimeridian
2681
+ # If data is lazy, this will load the data (data.where is not lazy AFAIK)
2682
+ cond = (
2683
+ (data["latitude"] >= south)
2684
+ & (data["latitude"] <= north)
2685
+ & (data["level"] >= level_min)
2686
+ & (data["level"] <= level_max)
2687
+ & ((data["longitude"] >= west) | (data["longitude"] <= east))
2688
+ )
2689
+ return data.where(cond, drop=True)
2690
+
2691
+
2692
+ def standardize_variables(ds: xr.Dataset, variables: Iterable[MetVariable]) -> xr.Dataset:
2693
+ """Rename all variables in dataset from short name to standard name.
2694
+
2695
+ This function does not change any variables in ``ds`` that are not found in ``variables``.
2696
+
2697
+ When there are multiple variables with the same short name, the last one is used.
2698
+
2699
+ Parameters
2700
+ ----------
2701
+ ds : DatasetType
2702
+ An :class:`xr.Dataset`.
2703
+ variables : Iterable[MetVariable]
2704
+ Data source variables
2705
+
2706
+ Returns
2707
+ -------
2708
+ DatasetType
2709
+ Dataset with variables renamed to standard names
2710
+ """
2711
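+     # Illustrative sketch (assuming the ECMWF convention where air temperature has
+     # short name "t" and standard name "air_temperature"): passing a variable like
+     # ``AirTemperature`` renames a "t" data variable to "air_temperature" and leaves
+     # all other variables unchanged.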
+ variables_dict: dict[Hashable, str] = {v.short_name: v.standard_name for v in variables}
2712
+ name_dict = {var: variables_dict[var] for var in ds.data_vars if var in variables_dict}
2713
+ return ds.rename(name_dict)
2714
+
2715
+
2716
+ def originates_from_ecmwf(met: MetDataset | MetDataArray) -> bool:
2717
+ """Check if data appears to have originated from an ECMWF source.
2718
+
2719
+ .. versionadded:: 0.27.0
2720
+
2721
+ Experimental. Implementation is brittle.
2722
+
2723
+ Parameters
2724
+ ----------
2725
+ met : MetDataset | MetDataArray
2726
+ Dataset or array to inspect.
2727
+
2728
+ Returns
2729
+ -------
2730
+ bool
2731
+ True if data appears to be derived from an ECMWF source.
2732
+
2733
+ See Also
2734
+ --------
2735
+ - :class:`ERA5`
2736
+ - :class:`HRES`
2737
+
2738
+ """
2739
+ if isinstance(met, MetDataset):
2740
+ try:
2741
+ return met.provider_attr == "ECMWF"
2742
+ except KeyError:
2743
+ pass
2744
+ return "ecmwf" in met.attrs.get("history", "")
2745
+
2746
+
2747
+ def _load(hash: str, cachestore: CacheStore, chunks: dict[str, int]) -> xr.Dataset:
2748
+ """Load xarray data from hash.
2749
+
2750
+ Parameters
2751
+ ----------
2752
+ hash : str
2753
+ Hash identifying previously saved files in the cache.
2754
+ cachestore : CacheStore
2755
+ Cache datastore used to locate the saved files.
2756
+ chunks : dict[str, int]
2757
+ Chunks passed to :func:`xarray.open_mfdataset` when opening the files.
2758
+
2759
+ Returns
2760
+ -------
2761
+ xr.Dataset
2762
+ Dataset loaded from the cached netCDF files.
2763
+ """
2764
+ disk_path = cachestore.get(f"{hash}*.nc")
2765
+ return xr.open_mfdataset(disk_path, chunks=chunks)
2766
+
2767
+
2768
+ def _add_vertical_coords(data: XArrayType) -> XArrayType:
2769
+ """Add "air_pressure" and "altitude" coordinates to data.
2770
+
2771
+ .. versionchanged:: 0.52.1
2772
+ Ensure that the ``dtype`` of the additional vertical coordinates agree
2773
+ with the ``dtype`` of the underlying gridded data.
2774
+ """
2775
+
2776
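+     # Worked sketch: a 250 hPa level gains air_pressure = 25000.0 [Pa] and
+     # altitude = units.pl_to_m(250.0) (about 10.4 km in the ISA), and both new
+     # coordinates take the dtype of the gridded data.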
+ data["level"].attrs.update(units="hPa", long_name="Pressure", positive="down")
2777
+
2778
+ # XXX: use the dtype of the data to determine the precision of these coordinates
2779
+ # There are two competing conventions here:
2780
+ # - coordinate data should be float64
2781
+ # - gridded data is typically float32
2782
+ # - air_pressure and altitude often play both roles
2783
+ # It is more important for air_pressure and altitude to be grid-aligned than to be
2784
+ # coordinate-aligned, so we use the dtype of the data to determine the precision of
2785
+ # these coordinates
2786
+ dtype = (
2787
+ np.result_type(*data.data_vars.values(), np.float32)
2788
+ if isinstance(data, xr.Dataset)
2789
+ else data.dtype
2790
+ )
2791
+
2792
+ level = data["level"].values
2793
+
2794
+ if "air_pressure" not in data.coords:
2795
+ data = data.assign_coords(air_pressure=("level", level * 100.0))
2796
+ data.coords["air_pressure"].attrs.update(
2797
+ standard_name=AirPressure.standard_name,
2798
+ long_name=AirPressure.long_name,
2799
+ units=AirPressure.units,
2800
+ )
2801
+ if data.coords["air_pressure"].dtype != dtype:
2802
+ data.coords["air_pressure"] = data.coords["air_pressure"].astype(dtype, copy=False)
2803
+
2804
+ if "altitude" not in data.coords:
2805
+ data = data.assign_coords(altitude=("level", units.pl_to_m(level)))
2806
+ data.coords["altitude"].attrs.update(
2807
+ standard_name=Altitude.standard_name,
2808
+ long_name=Altitude.long_name,
2809
+ units=Altitude.units,
2810
+ )
2811
+ if data.coords["altitude"].dtype != dtype:
2812
+ data.coords["altitude"] = data.coords["altitude"].astype(dtype, copy=False)
2813
+
2814
+ return data
2815
+
2816
+
2817
+ def _lowmem_boundscheck(time: npt.NDArray[np.datetime64], da: xr.DataArray) -> None:
2818
+ """Extra bounds check required with low-memory interpolation strategy.
2819
+
2820
+ Because the main loop in `_interp_lowmem` processes points between time steps
2821
+ in gridded data, it will never encounter points that are out-of-bounds in time
2822
+ and may fail to produce requested out-of-bounds errors.
2823
+ """
2824
+ da_time = da["time"].to_numpy()
2825
+ if not np.all((time >= da_time.min()) & (time <= da_time.max())):
2826
+ axis = da.get_axis_num("time")
2827
+ msg = f"One of the requested xi is out of bounds in dimension {axis}"
2828
+ raise ValueError(msg)
2829
+
2830
+
2831
+ def _lowmem_masks(
2832
+ time: npt.NDArray[np.datetime64], t_met: npt.NDArray[np.datetime64]
2833
+ ) -> Generator[npt.NDArray[np.bool_], None, None]:
2834
+ """Generate sequence of masks for low-memory interpolation."""
2835
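+     # Worked sketch: with point times ['00:30', '01:30'] and met times
+     # ['00:00', '01:00', '02:00'] (all on the same day), the generator yields one
+     # boolean mask per adjacent pair of met times: [True, False] then [False, True].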
+ t_met_max = t_met.max()
2836
+ t_met_min = t_met.min()
2837
+ inbounds = (time >= t_met_min) & (time <= t_met_max)
2838
+ if not np.any(inbounds):
2839
+ return
2840
+
2841
+ earliest = np.nanmin(time)
2842
+ istart = 0 if earliest < t_met_min else np.flatnonzero(t_met <= earliest).max()
2843
+ latest = np.nanmax(time)
2844
+ iend = t_met.size - 1 if latest > t_met_max else np.flatnonzero(t_met >= latest).min()
2845
+ if istart == iend:
2846
+ yield inbounds
2847
+ return
2848
+
2849
+ # Sequence of masks covers elements in time in the interval [t_met[istart], t_met[iend]].
2850
+ # The first iteration masks elements in the interval [t_met[istart], t_met[istart+1]]
2851
+ # (inclusive of both endpoints).
2852
+ # Subsequent iterations mask elements in the interval (t_met[i], t_met[i+1]]
2853
+ # (inclusive of right endpoint only).
2854
+ for i in range(istart, iend):
2855
+ mask = ((time >= t_met[i]) if i == istart else (time > t_met[i])) & (time <= t_met[i + 1])
2856
+ if np.any(mask):
2857
+ yield mask
2858
+
2859
+
2860
+ def maybe_downselect_mds(
2861
+ big_mds: MetDataset,
2862
+ little_mds: MetDataset | None,
2863
+ t0: np.datetime64,
2864
+ t1: np.datetime64,
2865
+ ) -> MetDataset:
2866
+ """Possibly downselect ``big_mds`` in the time domain to cover ``[t0, t1]``.
2867
+
2868
+ If possible, ``little_mds`` is recycled to avoid re-loading data.
2869
+
2870
+ This implementation assumes ``t0 <= t1``, but this is not enforced.
2871
+
2872
+ If ``little_mds`` already covers the time range, it is returned as-is.
2873
+
2874
+ If ``big_mds`` doesn't cover the time range, no error is raised.
2875
+
2876
+ Parameters
2877
+ ----------
2878
+ big_mds : MetDataset
2879
+ Larger MetDataset
2880
+ little_mds : MetDataset | None
2881
+ Smaller MetDataset. This is assumed to be a subset of ``big_mds``,
2882
+ though the implementation may work if this is not the case.
2883
+ t0, t1 : np.datetime64
2884
+ Time range to cover
2885
+
2886
+ Returns
2887
+ -------
2888
+ MetDataset
2889
+ MetDataset covering the time range ``[t0, t1]`` comprised of data from
2890
+ ``little_mds`` when possible, otherwise from ``big_mds``.
2891
+ """
2892
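+     # Worked sketch: with hourly big_mds times 00:00..06:00, little_mds=None,
+     # t0=02:30 and t1=03:30, the returned MetDataset keeps the 02:00, 03:00 and
+     # 04:00 time steps (one step of padding on each side of [t0, t1]).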
+ if little_mds is None:
2893
+ big_time = big_mds.indexes["time"].values
2894
+ i0 = np.searchsorted(big_time, t0, side="right").item()
2895
+ i0 = max(0, i0 - 1)
2896
+ i1 = np.searchsorted(big_time, t1, side="left").item()
2897
+ i1 = min(i1 + 1, big_time.size)
2898
+ return MetDataset._from_fastpath(big_mds.data.isel(time=slice(i0, i1)))
2899
+
2900
+ little_time = little_mds.indexes["time"].values
2901
+ if t0 >= little_time[0] and t1 <= little_time[-1]:
2902
+ return little_mds
2903
+
2904
+ big_time = big_mds.indexes["time"].values
2905
+ i0 = np.searchsorted(big_time, t0, side="right").item()
2906
+ i0 = max(0, i0 - 1)
2907
+ i1 = np.searchsorted(big_time, t1, side="left").item()
2908
+ i1 = min(i1 + 1, big_time.size)
2909
+ big_ds = big_mds.data.isel(time=slice(i0, i1))
2910
+ big_time = big_ds._indexes["time"].index.values # type: ignore[attr-defined]
2911
+
2912
+ # Select exactly the times in big_ds that are not in little_ds
2913
+ _, little_indices, big_indices = np.intersect1d(
2914
+ little_time, big_time, assume_unique=True, return_indices=True
2915
+ )
2916
+ little_ds = little_mds.data.isel(time=little_indices)
2917
+ filt = np.ones_like(big_time, dtype=bool)
2918
+ filt[big_indices] = False
2919
+ big_ds = big_ds.isel(time=filt)
2920
+
2921
+ # Manually load relevant parts of big_ds into memory before xr.concat
2922
+ # It appears that without this, xr.concat will forget the in-memory
2923
+ # arrays in little_ds
2924
+ for var, da in little_ds.items():
2925
+ if da._in_memory:
2926
+ da2 = big_ds[var]
2927
+ if not da2._in_memory:
2928
+ da2.load()
2929
+
2930
+ ds = xr.concat([little_ds, big_ds], dim="time")
2931
+ if not ds._indexes["time"].index.is_monotonic_increasing: # type: ignore[attr-defined]
2932
+ # Rarely would we enter this: t0 would have to be before the first
2933
+ # time in little_mds, and the various advection-based models generally
2934
+ # proceed forward in time.
2935
+ ds = ds.sortby("time")
2936
+ return MetDataset._from_fastpath(ds)