pycontrails 0.49.3__cp310-cp310-macosx_10_9_x86_64.whl → 0.49.5__cp310-cp310-macosx_10_9_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of pycontrails might be problematic. Click here for more details.

Files changed (29)
  1. pycontrails/_version.py +2 -2
  2. pycontrails/core/datalib.py +1 -1
  3. pycontrails/core/flight.py +11 -11
  4. pycontrails/core/interpolation.py +29 -19
  5. pycontrails/core/met.py +192 -104
  6. pycontrails/core/models.py +29 -15
  7. pycontrails/core/rgi_cython.cpython-310-darwin.so +0 -0
  8. pycontrails/core/vector.py +14 -15
  9. pycontrails/datalib/gfs/gfs.py +1 -1
  10. pycontrails/datalib/spire/spire.py +23 -19
  11. pycontrails/ext/synthetic_flight.py +3 -1
  12. pycontrails/models/accf.py +6 -4
  13. pycontrails/models/cocip/cocip.py +48 -18
  14. pycontrails/models/cocip/cocip_params.py +13 -10
  15. pycontrails/models/cocip/output_formats.py +62 -52
  16. pycontrails/models/cocipgrid/cocip_grid.py +459 -275
  17. pycontrails/models/cocipgrid/cocip_grid_params.py +12 -18
  18. pycontrails/models/emissions/ffm2.py +10 -8
  19. pycontrails/models/pcc.py +1 -1
  20. pycontrails/models/ps_model/ps_aircraft_params.py +1 -1
  21. pycontrails/models/ps_model/static/{ps-aircraft-params-20231117.csv → ps-aircraft-params-20240209.csv} +12 -3
  22. pycontrails/utils/json.py +12 -10
  23. {pycontrails-0.49.3.dist-info → pycontrails-0.49.5.dist-info}/METADATA +2 -2
  24. {pycontrails-0.49.3.dist-info → pycontrails-0.49.5.dist-info}/RECORD +28 -29
  25. pycontrails/models/cocipgrid/cocip_time_handling.py +0 -342
  26. {pycontrails-0.49.3.dist-info → pycontrails-0.49.5.dist-info}/LICENSE +0 -0
  27. {pycontrails-0.49.3.dist-info → pycontrails-0.49.5.dist-info}/NOTICE +0 -0
  28. {pycontrails-0.49.3.dist-info → pycontrails-0.49.5.dist-info}/WHEEL +0 -0
  29. {pycontrails-0.49.3.dist-info → pycontrails-0.49.5.dist-info}/top_level.txt +0 -0
pycontrails/_version.py CHANGED
@@ -12,5 +12,5 @@ __version__: str
12
12
  __version_tuple__: VERSION_TUPLE
13
13
  version_tuple: VERSION_TUPLE
14
14
 
15
- __version__ = version = '0.49.3'
16
- __version_tuple__ = version_tuple = (0, 49, 3)
15
+ __version__ = version = '0.49.5'
16
+ __version_tuple__ = version_tuple = (0, 49, 5)
@@ -353,7 +353,7 @@ class MetDataSource(abc.ABC):
353
353
  self,
354
354
  time: TimeInput | None,
355
355
  variables: VariableInput,
356
- pressure_levels: PressureLevelInput = [-1],
356
+ pressure_levels: PressureLevelInput = -1,
357
357
  paths: str | list[str] | pathlib.Path | list[pathlib.Path] | None = None,
358
358
  grid: float | None = None,
359
359
  **kwargs: Any,
@@ -936,8 +936,7 @@ class Flight(GeoVectorDataset):
936
936
 
937
937
  # Remove original index if requested
938
938
  if not keep_original_index:
939
- filt = df.index.isin(t)
940
- df = df.loc[filt]
939
+ df = df.loc[t]
941
940
 
942
941
  # finally reset index
943
942
  df = df.reset_index()
@@ -1200,11 +1199,13 @@ class Flight(GeoVectorDataset):
1200
1199
  dict[str, Any]
1201
1200
  Python representation of geojson FeatureCollection
1202
1201
  """
1203
- points = _return_linestring({
1204
- "longitude": self["longitude"],
1205
- "latitude": self["latitude"],
1206
- "altitude": self.altitude,
1207
- })
1202
+ points = _return_linestring(
1203
+ {
1204
+ "longitude": self["longitude"],
1205
+ "latitude": self["latitude"],
1206
+ "altitude": self.altitude,
1207
+ }
1208
+ )
1208
1209
  geometry = {"type": "LineString", "coordinates": points}
1209
1210
  properties = {
1210
1211
  "start_time": self.time_start.isoformat(),
@@ -2041,7 +2042,7 @@ def _resample_to_freq(df: pd.DataFrame, freq: str) -> tuple[pd.DataFrame, pd.Dat
2041
2042
  """Resample a DataFrame to a given frequency.
2042
2043
 
2043
2044
  This function is used to resample a DataFrame to a given frequency. The new
2044
- index will include all the original index values and the new esampled-to-freq
2045
+ index will include all the original index values and the new resampled-to-freq
2045
2046
  index values. The "longitude" and "latitude" columns will be linearly interpolated
2046
2047
  to the new index values.
2047
2048
 
@@ -2064,7 +2065,7 @@ def _resample_to_freq(df: pd.DataFrame, freq: str) -> tuple[pd.DataFrame, pd.Dat
2064
2065
  # and the resampled-to-freq index values.
2065
2066
  t0 = df.index[0]
2066
2067
  t1 = df.index[-1]
2067
- t = pd.date_range(t0, t1, freq=freq).floor(freq)
2068
+ t = pd.date_range(t0, t1, freq=freq, name="time").floor(freq)
2068
2069
  if t[0] < t0:
2069
2070
  t = t[1:]
2070
2071
 
@@ -2072,8 +2073,7 @@ def _resample_to_freq(df: pd.DataFrame, freq: str) -> tuple[pd.DataFrame, pd.Dat
2072
2073
  concat_arr = np.unique(concat_arr)
2073
2074
  concat_index = pd.DatetimeIndex(concat_arr, name="time", copy=False)
2074
2075
 
2075
- out = pd.DataFrame(index=concat_index, columns=df.columns, dtype=float)
2076
- out.loc[df.index] = df
2076
+ out = df.reindex(concat_index)
2077
2077
 
2078
2078
  # Linearly interpolate small horizontal gap
2079
2079
  coords = ["longitude", "latitude"]
@@ -66,7 +66,8 @@ class PycontrailsRegularGridInterpolator(scipy.interpolate.RegularGridInterpolat
66
66
  fill_value: float | np.float64 | None,
67
67
  ):
68
68
  if values.dtype not in (np.float32, np.float64):
69
- raise ValueError("values must be a float array")
69
+ msg = f"values must be a float array, not {values.dtype}"
70
+ raise ValueError(msg)
70
71
 
71
72
  self.grid = points
72
73
  self.values = values
@@ -94,7 +95,8 @@ class PycontrailsRegularGridInterpolator(scipy.interpolate.RegularGridInterpolat
94
95
  g0 = self.grid[i][0]
95
96
  g1 = self.grid[i][-1]
96
97
  if not (np.all(p >= g0) and np.all(p <= g1)):
97
- raise ValueError(f"One of the requested xi is out of bounds in dimension {i}")
98
+ msg = f"One of the requested xi is out of bounds in dimension {i}"
99
+ raise ValueError(msg)
98
100
 
99
101
  return np.zeros(xi.shape[0], dtype=bool)
100
102
 
@@ -213,7 +215,8 @@ class PycontrailsRegularGridInterpolator(scipy.interpolate.RegularGridInterpolat
213
215
  # np.interp could be better ... although that may also promote the dtype
214
216
  return rgi_cython.evaluate_linear_1d(values, indices, norm_distances, out)
215
217
 
216
- raise ValueError(f"Invalid number of dimensions: {ndim}")
218
+ msg = f"Invalid number of dimensions: {ndim}"
219
+ raise ValueError(msg)
217
220
 
218
221
 
219
222
  def _floatize_time(
@@ -442,18 +445,16 @@ def interp(
442
445
  coords = {"longitude": longitude, "latitude": latitude, "level": level, "time": time}
443
446
  da = _localize(da, coords)
444
447
 
445
- # Using da.coords.variables is slightly more performant than da["longitude"].values
446
- variables = da.coords.variables
447
- x = variables["longitude"].values
448
- y = variables["latitude"].values
449
- z = variables["level"].values
448
+ indexes = da._indexes
449
+ x = indexes["longitude"].index.to_numpy() # type: ignore[attr-defined]
450
+ y = indexes["latitude"].index.to_numpy() # type: ignore[attr-defined]
451
+ z = indexes["level"].index.to_numpy() # type: ignore[attr-defined]
450
452
  if any(v.dtype != np.float64 for v in (x, y, z)):
451
- raise ValueError(
452
- "da must have float64 dtype for longitude, latitude, and level coordinates"
453
- )
453
+ msg = "da must have float64 dtype for longitude, latitude, and level coordinates"
454
+ raise ValueError(msg)
454
455
 
455
456
  # Convert t and time to float64
456
- t = variables["time"].values
457
+ t = indexes["time"].index.to_numpy() # type: ignore[attr-defined]
457
458
  offset = t[0]
458
459
  t = _floatize_time(t, offset)
459
460
 
@@ -526,9 +527,11 @@ def _linear_interp_with_indices(
526
527
  indices: RGIArtifacts | None,
527
528
  ) -> tuple[npt.NDArray[np.float64], RGIArtifacts]:
528
529
  if interp.method != "linear":
529
- raise ValueError("Parameter 'indices' is only supported for 'method=linear'")
530
+ msg = "Parameter 'indices' is only supported for 'method=linear'"
531
+ raise ValueError(msg)
530
532
  if localize:
531
- raise ValueError("Parameter 'indices' is only supported for 'localize=False'")
533
+ msg = "Parameter 'indices' is only supported for 'localize=False'"
534
+ raise ValueError(msg)
532
535
 
533
536
  if indices is None:
534
537
  assert xi is not None, "xi must be provided if indices is None"
@@ -604,7 +607,10 @@ class EmissionsProfileInterpolator:
604
607
  """
605
608
 
606
609
  def __init__(
607
- self, xp: npt.NDArray[np.float64], fp: npt.NDArray[np.float64], drop_duplicates: bool = True
610
+ self,
611
+ xp: npt.NDArray[np.float64],
612
+ fp: npt.NDArray[np.float64],
613
+ drop_duplicates: bool = True,
608
614
  ) -> None:
609
615
  if drop_duplicates:
610
616
  # Using np.diff to detect duplicates ... this assumes xp is sorted.
@@ -622,13 +628,17 @@ class EmissionsProfileInterpolator:
622
628
 
623
629
  def _validate(self) -> None:
624
630
  if not len(self.xp):
625
- raise ValueError("xp must not be empty")
631
+ msg = "xp must not be empty"
632
+ raise ValueError(msg)
626
633
  if len(self.xp) != len(self.fp):
627
- raise ValueError("xp and fp must have the same length")
634
+ msg = "xp and fp must have the same length"
635
+ raise ValueError(msg)
628
636
  if not np.all(np.diff(self.xp) > 0.0):
629
- raise ValueError("xp must be strictly increasing")
637
+ msg = "xp must be strictly increasing"
638
+ raise ValueError(msg)
630
639
  if np.any(np.isnan(self.xp)):
631
- raise ValueError("xp must not contain nan values")
640
+ msg = "xp must not contain nan values"
641
+ raise ValueError(msg)
632
642
 
633
643
  def interp(self, x: npt.NDArray[np.float64]) -> npt.NDArray[np.float64]:
634
644
  """Interpolate x against xp and fp.
pycontrails/core/met.py CHANGED
@@ -23,6 +23,7 @@ from typing import (
23
23
 
24
24
  import numpy as np
25
25
  import numpy.typing as npt
26
+ import pandas as pd
26
27
  import xarray as xr
27
28
  from overrides import overrides
28
29
 
@@ -110,7 +111,7 @@ class MetBase(ABC, Generic[XArrayType]):
110
111
  ValueError
111
112
  If longitude values are not contained in the interval [-180, 180].
112
113
  """
113
- longitude = self.variables["longitude"].values
114
+ longitude = self.indexes["longitude"].to_numpy()
114
115
  if longitude.dtype != COORD_DTYPE:
115
116
  raise ValueError(
116
117
  "Longitude values must be of type float64. "
@@ -154,7 +155,7 @@ class MetBase(ABC, Generic[XArrayType]):
154
155
  ValueError
155
156
  If latitude values are not contained in the interval [-90, 90].
156
157
  """
157
- latitude = self.variables["latitude"].values
158
+ latitude = self.indexes["latitude"].to_numpy()
158
159
  if latitude.dtype != COORD_DTYPE:
159
160
  raise ValueError(
160
161
  "Latitude values must be of type float64. "
@@ -181,10 +182,11 @@ class MetBase(ABC, Generic[XArrayType]):
181
182
  ValueError
182
183
  If one of the coordinates is not sorted.
183
184
  """
184
- if not np.all(np.diff(self.variables["time"]) > np.timedelta64(0, "ns")):
185
+ indexes = self.indexes
186
+ if not np.all(np.diff(indexes["time"]) > np.timedelta64(0, "ns")):
185
187
  raise ValueError("Coordinate `time` not sorted. Initiate with `copy=True`.")
186
188
  for coord in self.dim_order[:3]: # exclude time, the 4th dimension
187
- if not np.all(np.diff(self.variables[coord]) > 0.0):
189
+ if not np.all(np.diff(indexes[coord]) > 0.0):
188
190
  raise ValueError(f"Coordinate '{coord}' not sorted. Initiate with 'copy=True'.")
189
191
 
190
192
  def _validate_transpose(self) -> None:
@@ -251,8 +253,9 @@ class MetBase(ABC, Generic[XArrayType]):
251
253
  self._validate_dim_contains_coords()
252
254
 
253
255
  # Ensure spatial coordinates all have dtype COORD_DTYPE
256
+ indexes = self.indexes
254
257
  for coord in ("longitude", "latitude", "level"):
255
- arr = self.variables[coord].values
258
+ arr = indexes[coord].to_numpy()
256
259
  if arr.dtype != COORD_DTYPE:
257
260
  self.data[coord] = arr.astype(COORD_DTYPE)
258
261
 
@@ -265,7 +268,7 @@ class MetBase(ABC, Generic[XArrayType]):
265
268
  if not self.is_wrapped:
266
269
  # Ensure longitude is contained in interval [-180, 180)
267
270
  # If longitude has value at 180, we might not want to shift it?
268
- lon = self.variables["longitude"].values
271
+ lon = self.indexes["longitude"].to_numpy()
269
272
 
270
273
  # This longitude shifting can give rise to precision errors with float32
271
274
  # Only shift if necessary
@@ -288,48 +291,12 @@ class MetBase(ABC, Generic[XArrayType]):
288
291
  # single level data
289
292
  if self.is_single_level:
290
293
  # add level attributes to reflect surface level
291
- self.data["level"].attrs.update(units="", long_name="Single Level")
294
+ level_attrs = self.data["level"].attrs
295
+ if not level_attrs:
296
+ level_attrs.update(units="", long_name="Single Level")
292
297
  return
293
298
 
294
- # pressure level data
295
- level = self.variables["level"].values
296
-
297
- # add pressure level attributes
298
- self.data["level"].attrs.update(units="hPa", long_name="Pressure", positive="down")
299
-
300
- # add altitude and air_pressure
301
-
302
- # XXX: use the dtype of the data to determine the precision of these coordinates
303
- # There are two competing conventions here:
304
- # - coordinate data should be float64
305
- # - gridded data is typically float32
306
- # - air_pressure and altitude often play both roles
307
- # It is more important for air_pressure and altitude to be grid-aligned than to be
308
- # coordinate-aligned, so we use the dtype of the data to determine the precision of
309
- # these coordinates
310
- if isinstance(self.data, xr.Dataset):
311
- dtype = np.result_type(*self.data.data_vars.values(), np.float32)
312
- else:
313
- dtype = self.data.dtype
314
-
315
- level = level.astype(dtype)
316
- air_pressure = level * 100.0
317
- altitude = units.pl_to_m(level)
318
- self.data = self.data.assign_coords({"air_pressure": ("level", air_pressure)})
319
- self.data = self.data.assign_coords({"altitude": ("level", altitude)})
320
-
321
- # add air_pressure units and long name attributes
322
- self.data.coords["air_pressure"].attrs.update(
323
- standard_name=AirPressure.standard_name,
324
- long_name=AirPressure.long_name,
325
- units=AirPressure.units,
326
- )
327
- # add altitude units and long name attributes
328
- self.data.coords["altitude"].attrs.update(
329
- standard_name=Altitude.standard_name,
330
- long_name=Altitude.long_name,
331
- units=Altitude.units,
332
- )
299
+ self.data = _add_vertical_coords(self.data)
333
300
 
334
301
  @property
335
302
  def hash(self) -> str:
@@ -382,29 +349,36 @@ class MetBase(ABC, Generic[XArrayType]):
382
349
  dict[str, np.ndarray]
383
350
  Dictionary of coordinates
384
351
  """
352
+ variables = self.indexes
385
353
  return {
386
- "longitude": self.variables["longitude"].values,
387
- "latitude": self.variables["latitude"].values,
388
- "level": self.variables["level"].values,
389
- "time": self.variables["time"].values,
354
+ "longitude": variables["longitude"].to_numpy(),
355
+ "latitude": variables["latitude"].to_numpy(),
356
+ "level": variables["level"].to_numpy(),
357
+ "time": variables["time"].to_numpy(),
390
358
  }
391
359
 
392
360
  @property
393
- def variables(self) -> dict[Hashable, xr.Variable]:
394
- """Low level access to underlying :attr:`data` variables.
361
+ def variables(self) -> dict[Hashable, pd.Index]:
362
+ """See :attr:`indexes`."""
363
+ warnings.warn(
364
+ "The 'variables' property is deprecated and will be removed in a future release. "
365
+ "Use 'indexes' instead.",
366
+ DeprecationWarning,
367
+ )
368
+ return self.indexes
395
369
 
396
- This method is typically is faster for accessing coordinate variables.
370
+ @property
371
+ def indexes(self) -> dict[Hashable, pd.Index]:
372
+ """Low level access to underlying :attr:`data` indexes.
373
+
374
+ This method is typically faster for accessing coordinate indexes.
397
375
 
398
376
  .. versionadded:: 0.25.2
399
377
 
400
378
  Returns
401
379
  -------
402
- dict[Hashable, xr.Variable]
403
- Dictionary of variables. The type is actually..
404
-
405
- xarray.core.utils.Frozen[Any, xr.Variable]
406
-
407
- In practice, this behaves like a dictionary.
380
+ dict[Hashable, pd.Index]
381
+ Dictionary of indexes.
408
382
 
409
383
  Examples
410
384
  --------
@@ -414,16 +388,14 @@ class MetBase(ABC, Generic[XArrayType]):
414
388
  >>> levels = [200, 300]
415
389
  >>> era5 = ERA5(times, variables, levels)
416
390
  >>> mds = era5.open_metdataset()
417
- >>> mds.variables["level"].values # faster access than mds.data["level"]
391
+ >>> mds.indexes["level"].to_numpy()
418
392
  array([200., 300.])
419
393
 
420
394
  >>> mda = mds["air_temperature"]
421
- >>> mda.variables["level"].values # faster access than mda.data["level"]
395
+ >>> mda.indexes["level"].to_numpy()
422
396
  array([200., 300.])
423
397
  """
424
- if isinstance(self.data, xr.Dataset):
425
- return self.data.variables # type: ignore[return-value]
426
- return self.data.coords.variables
398
+ return {k: v.index for k, v in self.data._indexes.items()} # type: ignore[attr-defined]
427
399
 
428
400
  @property
429
401
  def is_wrapped(self) -> bool:
@@ -449,7 +421,7 @@ class MetBase(ABC, Generic[XArrayType]):
449
421
  --------
450
422
  :func:`pycontrails.physics.geo.advect_longitude`
451
423
  """
452
- longitude = self.variables["longitude"].values
424
+ longitude = self.indexes["longitude"].to_numpy()
453
425
  return _is_wrapped(longitude)
454
426
 
455
427
  @property
@@ -464,7 +436,7 @@ class MetBase(ABC, Generic[XArrayType]):
464
436
  bool
465
437
  If instance contains single level data.
466
438
  """
467
- level = self.variables["level"].values
439
+ level = self.indexes["level"].to_numpy()
468
440
  return len(level) == 1 and level[0] == -1
469
441
 
470
442
  @abstractmethod
@@ -569,6 +541,82 @@ class MetBase(ABC, Generic[XArrayType]):
569
541
  """
570
542
  return _is_zarr(self.data)
571
543
 
544
+ def downselect_met(
545
+ self,
546
+ met: MetDataType,
547
+ *,
548
+ longitude_buffer: tuple[float, float] = (0.0, 0.0),
549
+ latitude_buffer: tuple[float, float] = (0.0, 0.0),
550
+ level_buffer: tuple[float, float] = (0.0, 0.0),
551
+ time_buffer: tuple[np.timedelta64, np.timedelta64] = (
552
+ np.timedelta64(0, "h"),
553
+ np.timedelta64(0, "h"),
554
+ ),
555
+ copy: bool = True,
556
+ ) -> MetDataType:
557
+ """Downselect ``met`` to encompass a spatiotemporal region of the data.
558
+
559
+ .. warning::
560
+
561
+ This method is analogous to :meth:`GeoVectorDataset.downselect_met`.
562
+ It does not change the instance data, but instead operates on the
563
+ ``met`` input. This method is different from :meth:`downselect` which
564
+ operates on the instance data.
565
+
566
+ Parameters
567
+ ----------
568
+ met : MetDataset | MetDataArray
569
+ MetDataset or MetDataArray to downselect.
570
+ longitude_buffer : tuple[float, float], optional
571
+ Extend the longitude domain by ``longitude_buffer[0]`` on the low side
572
+ and ``longitude_buffer[1]`` on the high side.
573
+ Units must be the same as class coordinates.
574
+ Defaults to ``(0, 0)`` degrees.
575
+ latitude_buffer : tuple[float, float], optional
576
+ Extend the latitude domain by ``latitude_buffer[0]`` on the low side
577
+ and ``latitude_buffer[1]`` on the high side.
578
+ Units must be the same as class coordinates.
579
+ Defaults to ``(0, 0)`` degrees.
580
+ level_buffer : tuple[float, float], optional
581
+ Extend the level domain by ``level_buffer[0]`` on the low side
582
+ and ``level_buffer[1]`` on the high side.
583
+ Units must be the same as class coordinates.
584
+ Defaults to ``(0, 0)`` [:math:`hPa`].
585
+ time_buffer : tuple[np.timedelta64, np.timedelta64], optional
586
+ Extend the time domain by ``time_buffer[0]`` on the low side
587
+ and ``time_buffer[1]`` on the high side.
588
+ Units must be the same as class coordinates.
589
+ Defaults to ``(np.timedelta64(0, "h"), np.timedelta64(0, "h"))``.
590
+ copy : bool
591
+ Whether the returned object is a copy or a view of the original. True by default.
592
+
593
+ Returns
594
+ -------
595
+ MetDataset | MetDataArray
596
+ Copy of downselected MetDataset or MetDataArray.
597
+ """
598
+ indexes = self.indexes
599
+ lon = indexes["longitude"].to_numpy()
600
+ lat = indexes["latitude"].to_numpy()
601
+ level = indexes["level"].to_numpy()
602
+ time = indexes["time"].to_numpy()
603
+
604
+ vector = vector_module.GeoVectorDataset(
605
+ longitude=[lon.min(), lon.max()],
606
+ latitude=[lat.min(), lat.max()],
607
+ level=[level.min(), level.max()],
608
+ time=[time.min(), time.max()],
609
+ )
610
+
611
+ return vector.downselect_met(
612
+ met,
613
+ longitude_buffer=longitude_buffer,
614
+ latitude_buffer=latitude_buffer,
615
+ level_buffer=level_buffer,
616
+ time_buffer=time_buffer,
617
+ copy=copy,
618
+ )
619
+
572
620
 
573
621
  class MetDataset(MetBase):
574
622
  """Meteorological dataset with multiple variables.
@@ -668,6 +716,8 @@ class MetDataset(MetBase):
668
716
  raise ValueError("Set 'copy=True' when using 'wrap_longitude=True'.")
669
717
  self.data = data
670
718
  self._validate_dims()
719
+ if not self.is_single_level:
720
+ self.data = _add_vertical_coords(self.data)
671
721
 
672
722
  def __getitem__(self, key: Hashable) -> MetDataArray:
673
723
  """Return DataArray of variable ``key`` cast to a :class:`MetDataArray` object.
@@ -1014,22 +1064,22 @@ class MetDataset(MetBase):
1014
1064
  crs EPSG:4326
1015
1065
 
1016
1066
  """
1017
- coords_keys = [str(key) for key in self.data.dims] # str not in Hashable
1018
- coords_vals = [self.variables[key].values for key in coords_keys]
1067
+ coords_keys = self.data.dims
1068
+ variables = self.indexes
1069
+ coords_vals = [variables[key].values for key in coords_keys]
1019
1070
  coords_meshes = np.meshgrid(*coords_vals, indexing="ij")
1020
- raveled_coords = [mesh.ravel() for mesh in coords_meshes]
1071
+ raveled_coords = (mesh.ravel() for mesh in coords_meshes)
1021
1072
  data = dict(zip(coords_keys, raveled_coords))
1022
1073
 
1023
- vector = vector_module.GeoVectorDataset(data, copy=False)
1024
- for key in self:
1074
+ out = vector_module.GeoVectorDataset(data, copy=False)
1075
+ for key, da in self.data.items():
1025
1076
  # The call to .values here will load the data if it is lazy
1026
- vector[key] = self[key].data.values.ravel()
1077
+ out[key] = da.values.ravel() # type: ignore[index]
1027
1078
 
1028
1079
  if transfer_attrs:
1029
- # vector.attrs expects keys to be strings .... we'll get an error
1030
- # if we cannot cast here
1031
- vector.attrs.update({str(k): v for k, v in self.attrs.items()})
1032
- return vector
1080
+ out.attrs.update(self.attrs) # type: ignore[arg-type]
1081
+
1082
+ return out
1033
1083
 
1034
1084
  def _get_pycontrails_attr_template(
1035
1085
  self,
@@ -1157,15 +1207,15 @@ class MetDataset(MetBase):
1157
1207
  >>> met = MetDataset.from_coords(longitude, latitude, level, time)
1158
1208
  >>> met
1159
1209
  MetDataset with data:
1160
- <xarray.Dataset>
1210
+ <xarray.Dataset> Size: 360B
1161
1211
  Dimensions: (longitude: 20, latitude: 20, level: 2, time: 1)
1162
1212
  Coordinates:
1163
- * longitude (longitude) float64 0.0 0.5 1.0 1.5 2.0 ... 8.0 8.5 9.0 9.5
1164
- * latitude (latitude) float64 0.0 0.5 1.0 1.5 2.0 ... 7.5 8.0 8.5 9.0 9.5
1165
- * level (level) float64 250.0 300.0
1166
- * time (time) datetime64[ns] 2019-01-01
1167
- air_pressure (level) float32 2.5e+04 3e+04
1168
- altitude (level) float32 1.036e+04 9.164e+03
1213
+ * longitude (longitude) float64 160B 0.0 0.5 1.0 1.5 ... 8.0 8.5 9.0 9.5
1214
+ * latitude (latitude) float64 160B 0.0 0.5 1.0 1.5 ... 8.0 8.5 9.0 9.5
1215
+ * level (level) float64 16B 250.0 300.0
1216
+ * time (time) datetime64[ns] 8B 2019-01-01
1217
+ air_pressure (level) float32 8B 2.5e+04 3e+04
1218
+ altitude (level) float32 8B 1.036e+04 9.164e+03
1169
1219
  Data variables:
1170
1220
  *empty*
1171
1221
 
@@ -1180,18 +1230,18 @@ class MetDataset(MetBase):
1180
1230
  >>> met["humidity"] = xr.DataArray(np.full(met.shape, 0.5), coords=met.coords)
1181
1231
  >>> met
1182
1232
  MetDataset with data:
1183
- <xarray.Dataset>
1233
+ <xarray.Dataset> Size: 13kB
1184
1234
  Dimensions: (longitude: 20, latitude: 20, level: 2, time: 1)
1185
1235
  Coordinates:
1186
- * longitude (longitude) float64 0.0 0.5 1.0 1.5 2.0 ... 8.0 8.5 9.0 9.5
1187
- * latitude (latitude) float64 0.0 0.5 1.0 1.5 2.0 ... 7.5 8.0 8.5 9.0 9.5
1188
- * level (level) float64 250.0 300.0
1189
- * time (time) datetime64[ns] 2019-01-01
1190
- air_pressure (level) float32 2.5e+04 3e+04
1191
- altitude (level) float32 1.036e+04 9.164e+03
1236
+ * longitude (longitude) float64 160B 0.0 0.5 1.0 1.5 ... 8.0 8.5 9.0 9.5
1237
+ * latitude (latitude) float64 160B 0.0 0.5 1.0 1.5 ... 8.0 8.5 9.0 9.5
1238
+ * level (level) float64 16B 250.0 300.0
1239
+ * time (time) datetime64[ns] 8B 2019-01-01
1240
+ air_pressure (level) float32 8B 2.5e+04 3e+04
1241
+ altitude (level) float32 8B 1.036e+04 9.164e+03
1192
1242
  Data variables:
1193
- temperature (longitude, latitude, level, time) float64 234.5 ... 234.5
1194
- humidity (longitude, latitude, level, time) float64 0.5 0.5 ... 0.5 0.5
1243
+ temperature (longitude, latitude, level, time) float64 6kB 234.5 ... 234.5
1244
+ humidity (longitude, latitude, level, time) float64 6kB 0.5 0.5 ... 0.5
1195
1245
 
1196
1246
  >>> # Convert to a GeoVectorDataset
1197
1247
  >>> vector = met.to_vector()
@@ -1888,8 +1938,9 @@ class MetDataArray(MetBase):
1888
1938
  from pycontrails.core import polygon
1889
1939
 
1890
1940
  # Convert to nested lists of coordinates for GeoJSON representation
1891
- longitude: npt.NDArray[np.float64] = self.variables["longitude"].values
1892
- latitude: npt.NDArray[np.float64] = self.variables["latitude"].values
1941
+ indexes = self.indexes
1942
+ longitude = indexes["longitude"].to_numpy()
1943
+ latitude = indexes["latitude"].to_numpy()
1893
1944
 
1894
1945
  mp = polygon.find_multipolygon(
1895
1946
  arr,
@@ -2097,9 +2148,9 @@ class MetDataArray(MetBase):
2097
2148
  volume = self.data.sel(time=time).values
2098
2149
 
2099
2150
  # convert from array index back to coordinates
2100
- longitude = self.variables["longitude"].values
2101
- latitude = self.variables["latitude"].values
2102
- altitude = units.pl_to_m(self.variables["level"].values)
2151
+ longitude = self.indexes["longitude"].values
2152
+ latitude = self.indexes["latitude"].values
2153
+ altitude = units.pl_to_m(self.indexes["level"].values)
2103
2154
 
2104
2155
  # Pad volume on all axes to close the volumes
2105
2156
  if closed:
@@ -2290,11 +2341,7 @@ def _wrap_longitude(data: XArrayType) -> XArrayType:
2290
2341
  ValueError
2291
2342
  If longitude values are already wrapped.
2292
2343
  """
2293
- if isinstance(data, xr.Dataset):
2294
- lon = data.variables["longitude"].values
2295
- else:
2296
- lon = data.coords.variables["longitude"].values
2297
-
2344
+ lon = data._indexes["longitude"].index.to_numpy() # type: ignore[attr-defined]
2298
2345
  if _is_wrapped(lon):
2299
2346
  raise ValueError("Longitude values are already wrapped")
2300
2347
 
@@ -2359,7 +2406,7 @@ def _extract_2d_arr_and_altitude(
2359
2406
  """
2360
2407
  # Determine level if not specified
2361
2408
  if level is None:
2362
- level_coord = mda.variables["level"].values
2409
+ level_coord = mda.indexes["level"].values
2363
2410
  if len(level_coord) == 1:
2364
2411
  level = level_coord[0]
2365
2412
  else:
@@ -2370,7 +2417,7 @@ def _extract_2d_arr_and_altitude(
2370
2417
 
2371
2418
  # Determine time if not specified
2372
2419
  if time is None:
2373
- time_coord = mda.variables["time"].values
2420
+ time_coord = mda.indexes["time"].values
2374
2421
  if len(time_coord) == 1:
2375
2422
  time = time_coord[0]
2376
2423
  else:
@@ -2533,3 +2580,44 @@ def _load(hash: str, cachestore: CacheStore, chunks: dict[str, int]) -> xr.Datas
2533
2580
  """
2534
2581
  disk_path = cachestore.get(f"{hash}*.nc")
2535
2582
  return xr.open_mfdataset(disk_path, chunks=chunks)
2583
+
2584
+
2585
+ def _add_vertical_coords(data: XArrayType) -> XArrayType:
2586
+ """Add "air_pressure" and "altitude" coordinates to data."""
2587
+
2588
+ data["level"].attrs.update(units="hPa", long_name="Pressure", positive="down")
2589
+
2590
+ coords = data.coords
2591
+ if "air_pressure" in coords and "altitude" in coords:
2592
+ return data
2593
+
2594
+ # XXX: use the dtype of the data to determine the precision of these coordinates
2595
+ # There are two competing conventions here:
2596
+ # - coordinate data should be float64
2597
+ # - gridded data is typically float32
2598
+ # - air_pressure and altitude often play both roles
2599
+ # It is more important for air_pressure and altitude to be grid-aligned than to be
2600
+ # coordinate-aligned, so we use the dtype of the data to determine the precision of
2601
+ # these coordinates
2602
+ if isinstance(data, xr.Dataset):
2603
+ dtype = np.result_type(*data.data_vars.values(), np.float32)
2604
+ else:
2605
+ dtype = data.dtype
2606
+ level = data["level"].values.astype(dtype, copy=False)
2607
+
2608
+ if "air_pressure" not in coords:
2609
+ data = data.assign_coords(air_pressure=("level", level * 100.0))
2610
+ data.coords["air_pressure"].attrs.update(
2611
+ standard_name=AirPressure.standard_name,
2612
+ long_name=AirPressure.long_name,
2613
+ units=AirPressure.units,
2614
+ )
2615
+ if "altitude" not in coords:
2616
+ data = data.assign_coords(altitude=("level", units.pl_to_m(level)))
2617
+ data.coords["altitude"].attrs.update(
2618
+ standard_name=Altitude.standard_name,
2619
+ long_name=Altitude.long_name,
2620
+ units=Altitude.units,
2621
+ )
2622
+
2623
+ return data