pycontrails 0.58.0__cp314-cp314-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of pycontrails might be problematic. Click here for more details.

Files changed (122) hide show
  1. pycontrails/__init__.py +70 -0
  2. pycontrails/_version.py +34 -0
  3. pycontrails/core/__init__.py +30 -0
  4. pycontrails/core/aircraft_performance.py +679 -0
  5. pycontrails/core/airports.py +228 -0
  6. pycontrails/core/cache.py +889 -0
  7. pycontrails/core/coordinates.py +174 -0
  8. pycontrails/core/fleet.py +483 -0
  9. pycontrails/core/flight.py +2185 -0
  10. pycontrails/core/flightplan.py +228 -0
  11. pycontrails/core/fuel.py +140 -0
  12. pycontrails/core/interpolation.py +702 -0
  13. pycontrails/core/met.py +2931 -0
  14. pycontrails/core/met_var.py +387 -0
  15. pycontrails/core/models.py +1321 -0
  16. pycontrails/core/polygon.py +549 -0
  17. pycontrails/core/rgi_cython.cp314-win_amd64.pyd +0 -0
  18. pycontrails/core/vector.py +2249 -0
  19. pycontrails/datalib/__init__.py +12 -0
  20. pycontrails/datalib/_met_utils/metsource.py +746 -0
  21. pycontrails/datalib/ecmwf/__init__.py +73 -0
  22. pycontrails/datalib/ecmwf/arco_era5.py +345 -0
  23. pycontrails/datalib/ecmwf/common.py +114 -0
  24. pycontrails/datalib/ecmwf/era5.py +554 -0
  25. pycontrails/datalib/ecmwf/era5_model_level.py +490 -0
  26. pycontrails/datalib/ecmwf/hres.py +804 -0
  27. pycontrails/datalib/ecmwf/hres_model_level.py +466 -0
  28. pycontrails/datalib/ecmwf/ifs.py +287 -0
  29. pycontrails/datalib/ecmwf/model_levels.py +435 -0
  30. pycontrails/datalib/ecmwf/static/model_level_dataframe_v20240418.csv +139 -0
  31. pycontrails/datalib/ecmwf/variables.py +268 -0
  32. pycontrails/datalib/geo_utils.py +261 -0
  33. pycontrails/datalib/gfs/__init__.py +28 -0
  34. pycontrails/datalib/gfs/gfs.py +656 -0
  35. pycontrails/datalib/gfs/variables.py +104 -0
  36. pycontrails/datalib/goes.py +757 -0
  37. pycontrails/datalib/himawari/__init__.py +27 -0
  38. pycontrails/datalib/himawari/header_struct.py +266 -0
  39. pycontrails/datalib/himawari/himawari.py +667 -0
  40. pycontrails/datalib/landsat.py +589 -0
  41. pycontrails/datalib/leo_utils/__init__.py +5 -0
  42. pycontrails/datalib/leo_utils/correction.py +266 -0
  43. pycontrails/datalib/leo_utils/landsat_metadata.py +300 -0
  44. pycontrails/datalib/leo_utils/search.py +250 -0
  45. pycontrails/datalib/leo_utils/sentinel_metadata.py +748 -0
  46. pycontrails/datalib/leo_utils/static/bq_roi_query.sql +6 -0
  47. pycontrails/datalib/leo_utils/vis.py +59 -0
  48. pycontrails/datalib/sentinel.py +650 -0
  49. pycontrails/datalib/spire/__init__.py +5 -0
  50. pycontrails/datalib/spire/exceptions.py +62 -0
  51. pycontrails/datalib/spire/spire.py +604 -0
  52. pycontrails/ext/bada.py +42 -0
  53. pycontrails/ext/cirium.py +14 -0
  54. pycontrails/ext/empirical_grid.py +140 -0
  55. pycontrails/ext/synthetic_flight.py +431 -0
  56. pycontrails/models/__init__.py +1 -0
  57. pycontrails/models/accf.py +425 -0
  58. pycontrails/models/apcemm/__init__.py +8 -0
  59. pycontrails/models/apcemm/apcemm.py +983 -0
  60. pycontrails/models/apcemm/inputs.py +226 -0
  61. pycontrails/models/apcemm/static/apcemm_yaml_template.yaml +183 -0
  62. pycontrails/models/apcemm/utils.py +437 -0
  63. pycontrails/models/cocip/__init__.py +29 -0
  64. pycontrails/models/cocip/cocip.py +2742 -0
  65. pycontrails/models/cocip/cocip_params.py +305 -0
  66. pycontrails/models/cocip/cocip_uncertainty.py +291 -0
  67. pycontrails/models/cocip/contrail_properties.py +1530 -0
  68. pycontrails/models/cocip/output_formats.py +2270 -0
  69. pycontrails/models/cocip/radiative_forcing.py +1260 -0
  70. pycontrails/models/cocip/radiative_heating.py +520 -0
  71. pycontrails/models/cocip/unterstrasser_wake_vortex.py +508 -0
  72. pycontrails/models/cocip/wake_vortex.py +396 -0
  73. pycontrails/models/cocip/wind_shear.py +120 -0
  74. pycontrails/models/cocipgrid/__init__.py +9 -0
  75. pycontrails/models/cocipgrid/cocip_grid.py +2552 -0
  76. pycontrails/models/cocipgrid/cocip_grid_params.py +138 -0
  77. pycontrails/models/dry_advection.py +602 -0
  78. pycontrails/models/emissions/__init__.py +21 -0
  79. pycontrails/models/emissions/black_carbon.py +599 -0
  80. pycontrails/models/emissions/emissions.py +1353 -0
  81. pycontrails/models/emissions/ffm2.py +336 -0
  82. pycontrails/models/emissions/static/default-engine-uids.csv +239 -0
  83. pycontrails/models/emissions/static/edb-gaseous-v29b-engines.csv +596 -0
  84. pycontrails/models/emissions/static/edb-nvpm-v29b-engines.csv +215 -0
  85. pycontrails/models/extended_k15.py +1327 -0
  86. pycontrails/models/humidity_scaling/__init__.py +37 -0
  87. pycontrails/models/humidity_scaling/humidity_scaling.py +1075 -0
  88. pycontrails/models/humidity_scaling/quantiles/era5-model-level-quantiles.pq +0 -0
  89. pycontrails/models/humidity_scaling/quantiles/era5-pressure-level-quantiles.pq +0 -0
  90. pycontrails/models/issr.py +210 -0
  91. pycontrails/models/pcc.py +326 -0
  92. pycontrails/models/pcr.py +154 -0
  93. pycontrails/models/ps_model/__init__.py +18 -0
  94. pycontrails/models/ps_model/ps_aircraft_params.py +381 -0
  95. pycontrails/models/ps_model/ps_grid.py +701 -0
  96. pycontrails/models/ps_model/ps_model.py +1000 -0
  97. pycontrails/models/ps_model/ps_operational_limits.py +525 -0
  98. pycontrails/models/ps_model/static/ps-aircraft-params-20250328.csv +69 -0
  99. pycontrails/models/ps_model/static/ps-synonym-list-20250328.csv +104 -0
  100. pycontrails/models/sac.py +442 -0
  101. pycontrails/models/tau_cirrus.py +183 -0
  102. pycontrails/physics/__init__.py +1 -0
  103. pycontrails/physics/constants.py +117 -0
  104. pycontrails/physics/geo.py +1138 -0
  105. pycontrails/physics/jet.py +968 -0
  106. pycontrails/physics/static/iata-cargo-load-factors-20250221.csv +74 -0
  107. pycontrails/physics/static/iata-passenger-load-factors-20250221.csv +74 -0
  108. pycontrails/physics/thermo.py +551 -0
  109. pycontrails/physics/units.py +472 -0
  110. pycontrails/py.typed +0 -0
  111. pycontrails/utils/__init__.py +1 -0
  112. pycontrails/utils/dependencies.py +66 -0
  113. pycontrails/utils/iteration.py +13 -0
  114. pycontrails/utils/json.py +187 -0
  115. pycontrails/utils/temp.py +50 -0
  116. pycontrails/utils/types.py +163 -0
  117. pycontrails-0.58.0.dist-info/METADATA +180 -0
  118. pycontrails-0.58.0.dist-info/RECORD +122 -0
  119. pycontrails-0.58.0.dist-info/WHEEL +5 -0
  120. pycontrails-0.58.0.dist-info/licenses/LICENSE +178 -0
  121. pycontrails-0.58.0.dist-info/licenses/NOTICE +43 -0
  122. pycontrails-0.58.0.dist-info/top_level.txt +3 -0
@@ -0,0 +1,702 @@
1
+ """Interpolation utilities."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import dataclasses
6
+ import logging
7
+ from typing import Any, Literal, overload
8
+
9
+ import numpy as np
10
+ import numpy.typing as npt
11
+ import scipy.interpolate
12
+ import xarray as xr
13
+
14
+ from pycontrails.core import rgi_cython # type: ignore[attr-defined]
15
+
16
+ # ------------------------------------------------------------------------------
17
+ # Multidimensional interpolation
18
+ # ------------------------------------------------------------------------------
19
+
20
+
21
# Module-level logger named after this module; used for debug output in ``_localize``.
logger = logging.getLogger(__name__)
22
+
23
+
24
class PycontrailsRegularGridInterpolator(scipy.interpolate.RegularGridInterpolator):
    """Regular-grid interpolator tuned for ``pycontrails`` linear interpolation.

    This subclass of :class:`scipy.interpolate.RegularGridInterpolator` speeds up
    the linear interpolation patterns that are typical in ``pycontrails``:

    #. When ``method="linear"`` and ``values`` has ``float32`` or ``float64`` dtype,
       the constructor bypasses the ``RegularGridInterpolator`` validation.
       In :func:`interp`, parameters are already crafted to fit the intended form,
       so the validation is unnecessary.
    #. The :meth:`_evaluate_linear` method is overridden with a faster
       implementation. See the :meth:`_evaluate_linear` docstring for details.

    **This class should not be used directly. Instead, use the ``interp`` function.**

    .. versionchanged:: 0.40.0

        The :meth:`_evaluate_linear` method now uses a Cython implementation. The dtype
        of the output is now consistent with the dtype of the underlying :attr:`values`

    .. versionchanged:: 0.58.0

        Any ``method`` other than ``"linear"`` now uses the
        :class:`scipy.interpolate.RegularGridInterpolator` implementation. This
        allows for greater flexibility in the ``method`` parameter.

    Parameters
    ----------
    points : tuple[npt.NDArray[np.floating], ...]
        Coordinates of the grid points.
    values : npt.NDArray[np.floating]
        Grid values. The shape of this array must be compatible with the
        coordinates.
    method : str
        Passed into :class:`scipy.interpolate.RegularGridInterpolator`
    bounds_error : bool
        Passed into :class:`scipy.interpolate.RegularGridInterpolator`
    fill_value : float | np.float64 | None
        Passed into :class:`scipy.interpolate.RegularGridInterpolator`

    See Also
    --------
    scipy.interpolate.RegularGridInterpolator
    interp
    """

    def __init__(
        self,
        points: tuple[npt.NDArray[np.floating], ...],
        values: npt.NDArray[np.floating],
        *,
        method: str,
        bounds_error: bool,
        fill_value: float | np.float64 | None,
    ) -> None:
        use_fast_path = method == "linear" and values.dtype in (np.float32, np.float64)

        if use_fast_path:
            # Fast path: assign the attributes directly, skipping parent validation
            self.grid = points
            self.values = values
            self.method = method
            self.bounds_error = bounds_error
            self.fill_value = fill_value
            self._spline = None  # XXX: setting private attribute on RGI
            return

        # Slow path: defer entirely to the parent class constructor
        super().__init__(
            points,
            values,
            method=method,
            bounds_error=bounds_error,
            fill_value=fill_value,
        )

    def _prepare_xi_simple(self, xi: npt.NDArray[np.floating]) -> npt.NDArray[np.bool_]:
        """Run a looser version of :meth:`_prepare_xi`.

        Parameters
        ----------
        xi : npt.NDArray[np.floating]
            Points at which to interpolate.

        Returns
        -------
        npt.NDArray[np.bool_]
            A 1-dimensional Boolean array indicating which points are out of bounds.
            If ``bounds_error`` is ``True``, this will be all ``False``.

        Raises
        ------
        ValueError
            If ``bounds_error`` is ``True`` and some coordinate in ``xi`` does not
            lie between the first and last grid value of its dimension.
        """
        if not self.bounds_error:
            return self._find_out_of_bounds(xi.T)  # XXX: calling private method on RGI

        for i, column in enumerate(xi.T):
            lower = self.grid[i][0]
            upper = self.grid[i][-1]
            inside = np.all(column >= lower) and np.all(column <= upper)
            if not inside:
                raise ValueError(f"One of the requested xi is out of bounds in dimension {i}")

        # Nothing raised, so no point is flagged
        return np.zeros(xi.shape[0], dtype=bool)

    def __call__(
        self, xi: npt.NDArray[np.floating], method: str | None = None
    ) -> npt.NDArray[np.floating]:
        """Evaluate the interpolator at the given points.

        Parameters
        ----------
        xi : npt.NDArray[np.floating]
            Points at which to interpolate. Must have shape ``(n, ndim)``, where
            ``ndim`` is the number of dimensions of the interpolator.
        method : str | None
            Override the :attr:`method` to keep parity with the base class.

        Returns
        -------
        npt.NDArray[np.floating]
            Interpolated values. Has shape ``(n,)``. When computing linear interpolation,
            the dtype is the same as the :attr:`values` array.
        """
        chosen = method or self.method

        if chosen == "linear":
            oob_mask = self._prepare_xi_simple(xi)
            # XXX: calling private method on RGI
            left_indices, distances = self._find_indices(xi.T)
            result = self._evaluate_linear(left_indices, distances)
            return self._set_out_of_bounds(result, oob_mask)

        # Every other method is handled by the scipy implementation
        return super().__call__(xi, chosen)

    def _set_out_of_bounds(
        self,
        out: npt.NDArray[np.floating],
        out_of_bounds: npt.NDArray[np.bool_],
    ) -> npt.NDArray[np.floating]:
        """Overwrite out-of-bounds entries of ``out`` with the fill value.

        Parameters
        ----------
        out : npt.NDArray[np.floating]
            Values from interpolation. This is modified in-place.
        out_of_bounds : npt.NDArray[np.bool_]
            A 1-dimensional Boolean array indicating which points are out of bounds.

        Returns
        -------
        out : npt.NDArray[np.floating]
            A reference to the ``out`` array.
        """
        needs_fill = self.fill_value is not None and np.any(out_of_bounds)
        if needs_fill:
            out[out_of_bounds] = self.fill_value
        return out

    def _evaluate_linear(
        self,
        indices: npt.NDArray[np.int64],
        norm_distances: npt.NDArray[np.floating],
    ) -> npt.NDArray[np.floating]:
        """Evaluate the interpolator using linear interpolation.

        This is a faster alternative to
        :meth:`scipy.interpolate.RegularGridInterpolator._evaluate_linear`.

        .. versionadded:: 0.24

        .. versionchanged:: 0.40.0

            Use Cython routines for evaluating the interpolation when the
            dimension is 1, 2, 3, or 4.

        Parameters
        ----------
        indices : npt.NDArray[np.int64]
            Indices of the grid points to the left of the interpolation points.
            Has shape ``(ndim, n_points)``.
        norm_distances : npt.NDArray[np.floating]
            Normalized distances between the interpolation points and the grid
            points to the left. Has shape ``(ndim, n_points)``.

        Returns
        -------
        npt.NDArray[np.floating]
            Interpolated values with shape ``(n_points,)`` and the same dtype as
            the :attr:`values` array.

        Raises
        ------
        ValueError
            If no non-degenerate dimension remains after squeezing.
        """
        # Grids with more than four dimensions go to the scipy "slow" implementation
        if indices.shape[0] > 4:
            return super()._evaluate_linear(indices, norm_distances)

        # The cython routines require non-degenerate arrays, so drop every
        # length-1 axis from the values and the matching rows from the artifacts
        keep = tuple(size > 1 for size in self.values.shape)
        squeezed = self.values.squeeze()
        kept_indices = indices[keep, :]
        kept_distances = norm_distances[keep, :]

        ndim, n_points = kept_indices.shape
        out = np.empty(n_points, dtype=self.values.dtype)

        if ndim == 1:
            # np.interp could be better ... although that may also promote the dtype
            # 1-d view is required for evaluate_linear_1d
            return rgi_cython.evaluate_linear_1d(
                squeezed, kept_indices[0, :], kept_distances[0, :], out
            )

        if ndim == 2:
            return rgi_cython.evaluate_linear_2d(squeezed, kept_indices, kept_distances, out)

        if ndim == 3:
            return rgi_cython.evaluate_linear_3d(squeezed, kept_indices, kept_distances, out)

        if ndim == 4:
            return rgi_cython.evaluate_linear_4d(squeezed, kept_indices, kept_distances, out)

        raise ValueError(f"Invalid number of dimensions: {ndim}")
241
+
242
+
243
+ def _floatize_time(
244
+ time: npt.NDArray[np.datetime64], offset: np.datetime64
245
+ ) -> npt.NDArray[np.floating]:
246
+ """Convert an array of ``np.datetime64`` to an array of ``np.float64``.
247
+
248
+ In calls to :class:`scipy.interpolate.RegularGridInterpolator`, it's critical
249
+ that every coordinate be of same type. This creates complications: spatial
250
+ coordinates are float-like, whereas time coordinates are datetime-like. In
251
+ particular, it is not possible to cast an ``np.datetime64`` to a float
252
+ without losing information. In practice, this is not problematic because
253
+ ``np.float64`` has plenty of precision. Previously, this was more of an issue
254
+ because we used ``np.float32``.
255
+
256
+ This function uses a fixed time resolution (1 millisecond) to convert the time
257
+ coordinate to a float-like coordinate. The time resolution is taken to avoid
258
+ losing too much information for the time scales we encounter.
259
+
260
+ Care is taken to ensure "nat" values are converted to "nan".
261
+
262
+ Note that ``xarray`` also must confront this issue. They take a similar approach
263
+ in :func:`xarray.core.missing._floatize_x`. See
264
+ https://github.com/pydata/xarray/blob/d4db16699f30ad1dc3e6861601247abf4ac96567/xarray/core/missing.py#L572
265
+
266
+ .. versionchanged:: 0.40.0
267
+
268
+ No longer allow the option of converting to ``np.float32``. No longer
269
+ floor the time values to the preceding millisecond.
270
+
271
+ Parameters
272
+ ----------
273
+ time : npt.NDArray[np.datetime64]
274
+ Array of ``np.datetime64`` values.
275
+ offset : np.datetime64
276
+ The offset to subtract from ``time``.
277
+
278
+ Returns
279
+ -------
280
+ npt.NDArray[np.floating]
281
+ The number of milliseconds since ``offset``.
282
+ """
283
+ delta = time - offset
284
+ resolution = np.timedelta64(1, "ms")
285
+ return delta / resolution
286
+
287
+
288
def _localize(da: xr.DataArray, coords: dict[str, np.ndarray]) -> xr.DataArray:
    """Clip ``da`` to the smallest bounding box that contains all of ``coords``.

    Roughly follows approach taken by :func:`xarray.core.missing._localize`. See
    https://github.com/pydata/xarray/blob/56f05c37924071eb4712479d47432aafd4dce38b/xarray/core/missing.py#L557

    Parameters
    ----------
    da : xr.DataArray
        DataArray to clip.
    coords : dict[str, np.ndarray]
        Coordinates to clip to. Each key must name a dimension of ``da``. NaN
        values are ignored when computing the bounding box.

    Returns
    -------
    xr.DataArray
        Clipped :class:`xarray.DataArray`. Has the same dimensions as the input ``da``.
        In particular, each dimension of the returned DataArray is a slice of the
        corresponding dimension of the input ``da``.
    """
    indexes: dict[str, Any] = {}

    # The in-bounds count below is only used for logging, so skip it unless needed
    debug_enabled = logger.isEnabledFor(logging.DEBUG)

    for dim, arr in coords.items():
        dim_vals = da[dim].values

        # Skip single level (marked by the sentinel level value -1)
        if dim == "level" and dim_vals.size == 1 and dim_vals.item() == -1:
            continue

        # Create slice: one grid point of padding on each side of [minval, maxval]
        minval = np.nanmin(arr)
        maxval = np.nanmax(arr)
        imin = np.searchsorted(dim_vals, minval, side="right") - 1
        imin = max(0, imin)
        imax = np.searchsorted(dim_vals, maxval, side="left") + 1
        indexes[dim] = slice(imin, imax)

        # Logging: count the points lying within the grid extent along this dimension.
        # Comparing ``arr`` to its own min / max would trivially count every non-nan point.
        if debug_enabled and dim_vals.size:
            n_in_bounds = np.sum((arr >= dim_vals[0]) & (arr <= dim_vals[-1]))
            logger.debug(
                "Interpolation in bounds along dimension %s: %d/%d",
                dim,
                n_in_bounds,
                arr.size,
            )

    return da.isel(**indexes)
334
+
335
+
336
@overload
def interp(
    longitude: npt.NDArray[np.floating],
    latitude: npt.NDArray[np.floating],
    level: npt.NDArray[np.floating],
    time: npt.NDArray[np.datetime64],
    da: xr.DataArray,
    method: str,
    bounds_error: bool,
    fill_value: float | np.float64 | None,
    localize: bool,
    *,
    indices: RGIArtifacts | None = ...,
    return_indices: Literal[False] = ...,
) -> npt.NDArray[np.floating]: ...


@overload
def interp(
    longitude: npt.NDArray[np.floating],
    latitude: npt.NDArray[np.floating],
    level: npt.NDArray[np.floating],
    time: npt.NDArray[np.datetime64],
    da: xr.DataArray,
    method: str,
    bounds_error: bool,
    fill_value: float | np.float64 | None,
    localize: bool,
    *,
    indices: RGIArtifacts | None = ...,
    return_indices: Literal[True],
) -> tuple[npt.NDArray[np.floating], RGIArtifacts]: ...


@overload
def interp(
    longitude: npt.NDArray[np.floating],
    latitude: npt.NDArray[np.floating],
    level: npt.NDArray[np.floating],
    time: npt.NDArray[np.datetime64],
    da: xr.DataArray,
    method: str,
    bounds_error: bool,
    fill_value: float | np.float64 | None,
    localize: bool,
    *,
    indices: RGIArtifacts | None = ...,
    return_indices: bool = ...,
) -> npt.NDArray[np.floating] | tuple[npt.NDArray[np.floating], RGIArtifacts]: ...


def interp(
    longitude: npt.NDArray[np.floating],
    latitude: npt.NDArray[np.floating],
    level: npt.NDArray[np.floating],
    time: npt.NDArray[np.datetime64],
    da: xr.DataArray,
    method: str,
    bounds_error: bool,
    fill_value: float | np.float64 | None,
    localize: bool,
    *,
    indices: RGIArtifacts | None = None,
    return_indices: bool = False,
) -> npt.NDArray[np.floating] | tuple[npt.NDArray[np.floating], RGIArtifacts]:
    """Interpolate over a grid with ``localize`` option.

    .. versionchanged:: 0.25.6

        Utilize scipy 1.9 upgrades to remove singleton dimensions.

    .. versionchanged:: 0.26.0

        Include ``indices`` and ``return_indices`` experimental parameters.
        Currently, nan values in ``longitude``, ``latitude``, ``level``, or ``time``
        are always propagated through to the output, regardless of ``bounds_error``.
        In other words, a ``ValueError`` for an out of bounds coordinate is only raised
        if a non-nan value is out of bounds.

    .. versionchanged:: 0.40.0

        When ``return_indices`` is True, an instance of :class:`RGIArtifacts`
        is used to store the indices artifacts.

    Parameters
    ----------
    longitude, latitude, level, time : np.ndarray
        Coordinates of points to be interpolated. These parameters have the same
        meaning as ``x`` in analogy with :func:`numpy.interp`. All four of these
        arrays must be 1 dimensional of the same size.
    da : xr.DataArray
        Gridded data interpolated over. Must adhere to ``MetDataArray`` conventions.
        In particular, the dimensions of ``da`` must be ``longitude``, ``latitude``,
        ``level``, and ``time``. The three spatial dimensions must be monotonically
        increasing with ``float64`` dtype. The ``time`` dimension must be
        monotonically increasing with ``datetime64`` dtype.
        Assumed to be cheap to load into memory (:attr:`xr.DataArray.values` is
        used without hesitation).
    method : str
        Passed into :class:`scipy.interpolate.RegularGridInterpolator`.
    bounds_error : bool
        Passed into :class:`scipy.interpolate.RegularGridInterpolator`.
    fill_value : float | np.float64 | None
        Passed into :class:`scipy.interpolate.RegularGridInterpolator`.
    localize : bool
        If True, clip ``da`` to the smallest bounding box that contains all of
        the interpolation coordinates.
    indices : RGIArtifacts | None, optional
        Experimental. Provide intermediate artifacts computed by
        :meth:`scipy.interpolate.RegularGridInterpolator._find_indices`
        to avoid redundant computation. If known and provided, this can speed
        up interpolation by avoiding an unnecessary call to ``_find_indices``.
        By default, None. Must be used precisely.
    return_indices : bool, optional
        If True, return output of :meth:`scipy.interpolate.RegularGridInterpolator._find_indices`
        in addition to interpolated values.

    Returns
    -------
    npt.NDArray[np.floating] | tuple[npt.NDArray[np.floating], RGIArtifacts]
        Interpolated values with same size as ``longitude``. If ``return_indices``
        is True, return intermediate indices artifacts as well.

    Raises
    ------
    ValueError
        If the ``longitude``, ``latitude``, or ``level`` coordinate of ``da`` does not
        have ``float64`` dtype.

    See Also
    --------
    pycontrails.MetDataArray.interpolate
    scipy.interpolate.interpn
    scipy.interpolate.RegularGridInterpolator
    """
    if localize:
        da = _localize(
            da,
            {"longitude": longitude, "latitude": latitude, "level": level, "time": time},
        )

    # Read the grid coordinates straight from the underlying indexes
    da_indexes = da._indexes
    x = da_indexes["longitude"].index.values  # type: ignore[attr-defined]
    y = da_indexes["latitude"].index.values  # type: ignore[attr-defined]
    z = da_indexes["level"].index.values  # type: ignore[attr-defined]
    if not all(grid.dtype == np.float64 for grid in (x, y, z)):
        raise ValueError(
            "da must have float64 dtype for longitude, latitude, and level coordinates"
        )

    # Convert the grid time coordinate to float64, measured from the first grid time
    t = da_indexes["time"].index.values  # type: ignore[attr-defined]
    offset = t[0]
    t = _floatize_time(t, offset)

    # A lone level of -1 marks single-level data: drop the level axis entirely
    single_level = z.size == 1 and z.item() == -1.0
    points: tuple[npt.NDArray[np.floating], ...]
    if single_level:
        points = (x, y, t)
        values = da.values.squeeze(axis=2)
    else:
        points = (x, y, z, t)
        values = da.values

    interp_ = PycontrailsRegularGridInterpolator(
        points=points,
        values=values,
        method=method,
        bounds_error=bounds_error,
        fill_value=fill_value,
    )

    if indices is not None:
        # Precomputed artifacts were supplied: no need to build xi
        out, artifacts = _linear_interp_with_indices(interp_, None, localize, indices)
        return (out, artifacts) if return_indices else out

    xi = _buildxi(longitude, latitude, level, time, offset, single_level)
    if not return_indices:
        return interp_(xi)

    out, artifacts = _linear_interp_with_indices(interp_, xi, localize, None)
    return out, artifacts
510
+
511
+
512
def _buildxi(
    longitude: npt.NDArray[np.floating],
    latitude: npt.NDArray[np.floating],
    level: npt.NDArray[np.floating],
    time: npt.NDArray[np.datetime64],
    offset: np.datetime64,
    single_level: bool,
) -> npt.NDArray[np.floating]:
    """Build the input array for interpolation.

    The implementation below achieves the same result as the following::

        np.stack([longitude, latitude, level, time_float], axis=1)

    This implementation is slightly faster than the above.

    Parameters
    ----------
    longitude, latitude, level : npt.NDArray[np.floating]
        Spatial coordinates of the points to interpolate. ``level`` is left out
        of the output when ``single_level`` is True.
    time : npt.NDArray[np.datetime64]
        Time coordinates of the points to interpolate.
    offset : np.datetime64
        Reference time subtracted from ``time`` before conversion to float.
    single_level : bool
        If True, build a 3-column array without a level column.

    Returns
    -------
    npt.NDArray[np.floating]
        Array of dtype ``float64`` with shape ``(longitude.size, ndim)``, where
        ``ndim`` is 3 if ``single_level`` and 4 otherwise. Time is the last column.
    """
    time_float = _floatize_time(time, offset)

    if single_level:
        columns = (longitude, latitude, time_float)
    else:
        columns = (longitude, latitude, level, time_float)

    # Fill a preallocated array column by column
    xi = np.empty((longitude.size, len(columns)), dtype=np.float64)
    for position, column in enumerate(columns):
        xi[:, position] = column

    return xi
542
+
543
+
544
def _linear_interp_with_indices(
    interp: PycontrailsRegularGridInterpolator,
    xi: npt.NDArray[np.floating] | None,
    localize: bool,
    indices: RGIArtifacts | None,
) -> tuple[npt.NDArray[np.floating], RGIArtifacts]:
    """Run linear interpolation, computing or reusing the index artifacts.

    Parameters
    ----------
    interp : PycontrailsRegularGridInterpolator
        Interpolator to evaluate. Its ``method`` must be ``"linear"``.
    xi : npt.NDArray[np.floating] | None
        Points at which to interpolate. Only used, and then required, when
        ``indices`` is None.
    localize : bool
        Must be False; reusing artifacts is not supported with localization.
    indices : RGIArtifacts | None
        Precomputed artifacts. If None, they are computed from ``xi``.

    Returns
    -------
    tuple[npt.NDArray[np.floating], RGIArtifacts]
        The interpolated values and the artifacts used to compute them.

    Raises
    ------
    ValueError
        If ``interp.method`` is not ``"linear"`` or if ``localize`` is True.
    """
    if interp.method != "linear":
        raise ValueError("Parameter 'indices' is only supported for 'method=linear'")
    if localize:
        raise ValueError("Parameter 'indices' is only supported for 'localize=False'")

    artifacts = indices
    if artifacts is None:
        assert xi is not None, "xi must be provided if indices is None"
        # The bounds check runs first so that an out-of-bounds error surfaces
        # before any indices are computed
        oob_mask = interp._prepare_xi_simple(xi)
        left_indices, distances = interp._find_indices(xi.T)
        artifacts = RGIArtifacts(
            xi_indices=left_indices,
            norm_distances=distances,
            out_of_bounds=oob_mask,
        )

    raw = interp._evaluate_linear(artifacts.xi_indices, artifacts.norm_distances)
    return interp._set_out_of_bounds(raw, artifacts.out_of_bounds), artifacts
566
+
567
+
568
@dataclasses.dataclass
class RGIArtifacts:
    """An interface to intermediate RGI interpolation artifacts.

    Groups the three arrays produced while preparing a linear interpolation so
    that they can be returned by :func:`interp` and passed back in on a later call.
    """

    # First output of ``_find_indices``: indices of the grid points to the left
    # of the interpolation points
    xi_indices: npt.NDArray[np.int64]
    # Second output of ``_find_indices``: normalized distances between the
    # interpolation points and the grid points to the left
    norm_distances: npt.NDArray[np.floating]
    # 1-dimensional Boolean mask of the points that are out of bounds
    out_of_bounds: npt.NDArray[np.bool_]
575
+
576
+
577
+ # ------------------------------------------------------------------------------
578
+ # 1 dimensional interpolation
579
+ # ------------------------------------------------------------------------------
580
+
581
+
582
+ class EmissionsProfileInterpolator:
583
+ """Support for interpolating a profile on a linear or logarithmic scale.
584
+
585
+ This class simply wraps :func:`numpy.interp` with fixed values for the
586
+ ``xp`` and ``fp`` arguments. Unlike :class:`xarray.DataArray` interpolation,
587
+ the :func:`numpy.interp` automatically clips values outside the range of the
588
+ ``xp`` array.
589
+
590
+ Parameters
591
+ ----------
592
+ xp : npt.NDArray[np.floating]
593
+ Array of x-values. These must be strictly increasing and free from
594
+ any nan values. Passed to :func:`numpy.interp`.
595
+ fp : npt.NDArray[np.floating]
596
+ Array of y-values. Passed to :func:`numpy.interp`.
597
+ drop_duplicates : bool, optional
598
+ Whether to drop duplicate values in ``xp``. Defaults to ``True``.
599
+
600
+ Examples
601
+ --------
602
+ >>> xp = np.array([3, 7, 10, 30], dtype=float)
603
+ >>> fp = np.array([0.1, 0.2, 0.3, 0.4], dtype=float)
604
+ >>> epi = EmissionsProfileInterpolator(xp, fp)
605
+ >>> # Interpolate a single value
606
+ >>> epi.interp(5)
607
+ np.float64(0.150000...)
608
+
609
+ >>> # Interpolate a single value on a logarithmic scale
610
+ >>> epi.log_interp(5)
611
+ np.float64(1.105171...)
612
+
613
+ >>> # Demonstrate speed up compared with xarray.DataArray interpolation
614
+ >>> import time, xarray as xr
615
+ >>> da = xr.DataArray(fp, dims=["x"], coords={"x": xp})
616
+
617
+ >>> inputs = [np.random.uniform(0, 31, size=200) for _ in range(1000)]
618
+ >>> t0 = time.perf_counter()
619
+ >>> xr_out = [da.interp(x=x.clip(3, 30)).values for x in inputs]
620
+ >>> t1 = time.perf_counter()
621
+ >>> np_out = [epi.interp(x) for x in inputs]
622
+ >>> t2 = time.perf_counter()
623
+ >>> assert np.allclose(xr_out, np_out)
624
+
625
+ >>> # We see a 100 fold speed up (more like 500x faster, but we don't
626
+ >>> # want the test to fail!)
627
+ >>> assert t2 - t1 < (t1 - t0) / 100
628
+ """
629
+
630
+ def __init__(
631
+ self,
632
+ xp: npt.NDArray[np.floating],
633
+ fp: npt.NDArray[np.floating],
634
+ drop_duplicates: bool = True,
635
+ ) -> None:
636
+ if drop_duplicates:
637
+ # Using np.diff to detect duplicates ... this assumes xp is sorted.
638
+ # If xp is not sorted, an ValueError will be raised in _validate
639
+ mask = np.abs(np.diff(xp, prepend=np.inf)) < 1e-15 # small tolerance
640
+ xp = xp[~mask]
641
+ fp = fp[~mask]
642
+
643
+ self.xp = np.asarray(xp)
644
+ self.fp = np.asarray(fp)
645
+ self._validate()
646
+
647
+ def __repr__(self) -> str:
648
+ return f"{self.__class__.__name__}(xp={self.xp}, fp={self.fp})"
649
+
650
+ def _validate(self) -> None:
651
+ if not len(self.xp):
652
+ msg = "xp must not be empty"
653
+ raise ValueError(msg)
654
+ if len(self.xp) != len(self.fp):
655
+ msg = "xp and fp must have the same length"
656
+ raise ValueError(msg)
657
+ if not np.all(np.diff(self.xp) > 0.0):
658
+ msg = "xp must be strictly increasing"
659
+ raise ValueError(msg)
660
+ if np.any(np.isnan(self.xp)):
661
+ msg = "xp must not contain nan values"
662
+ raise ValueError(msg)
663
+
664
+ def interp(self, x: npt.NDArray[np.floating]) -> npt.NDArray[np.floating]:
665
+ """Interpolate x against xp and fp.
666
+
667
+ Parameters
668
+ ----------
669
+ x : npt.NDArray[np.floating]
670
+ Array of x-values to interpolate.
671
+
672
+ Returns
673
+ -------
674
+ npt.NDArray[np.floating]
675
+ Array of interpolated y-values arising from the x-values. The ``dtype`` of
676
+ the output array is the same as the ``dtype`` of ``x``.
677
+ """
678
+ # Need to explicitly cast back to x.dtype
679
+ # https://github.com/numpy/numpy/issues/11214
680
+ x = np.asarray(x)
681
+ dtype = np.result_type(x, np.float32)
682
+ return np.interp(x, self.xp, self.fp).astype(dtype)
683
+
684
+ def log_interp(self, x: npt.NDArray[np.floating]) -> npt.NDArray[np.floating]:
685
+ """Interpolate x against xp and fp on a logarithmic scale.
686
+
687
+ This method composes the following three functions.
688
+ 1. :func:`numpy.log`
689
+ 2. :meth:`interp`
690
+ 3. :func:`numpy.exp`
691
+
692
+ Parameters
693
+ ----------
694
+ x : npt.NDArray[np.floating]
695
+ Array of x-values to interpolate.
696
+
697
+ Returns
698
+ -------
699
+ npt.NDArray[np.floating]
700
+ Array of interpolated y-values arising from the x-values.
701
+ """
702
+ return np.exp(self.interp(np.log(x)))