axsdb 0.0.2__py3-none-any.whl → 0.0.3__py3-none-any.whl
- axsdb/core.py +40 -31
- axsdb/factory.py +3 -2
- axsdb/interpolation.py +803 -0
- axsdb/math.py +503 -0
- axsdb/testing/__init__.py +0 -0
- axsdb/testing/fixtures.py +77 -0
- {axsdb-0.0.2.dist-info → axsdb-0.0.3.dist-info}/METADATA +7 -5
- axsdb-0.0.3.dist-info/RECORD +17 -0
- {axsdb-0.0.2.dist-info → axsdb-0.0.3.dist-info}/WHEEL +1 -1
- axsdb-0.0.2.dist-info/RECORD +0 -13
- {axsdb-0.0.2.dist-info → axsdb-0.0.3.dist-info}/entry_points.txt +0 -0
axsdb/interpolation.py
ADDED
@@ -0,0 +1,803 @@
"""
Fast xarray interpolation.

This module provides high-performance interpolation functions for xarray
DataArrays that bypass xarray's built-in interpolation. This is motivated
by performance regressions in recent xarray versions (see
https://github.com/pydata/xarray/issues/10683).
"""

from __future__ import annotations

from collections.abc import Hashable
from typing import Literal

import numpy as np
import xarray as xr
from scipy.interpolate import interpn

from .math import interp1d, lerp, lerp_indices

def _apply_fill_values(
    data: np.ndarray,
    new_coords_arr: np.ndarray,
    old_coords_arr: np.ndarray,
    dim_fill_value: float | tuple[float, float],
) -> None:
    """
    Replace NaN values with specified fill values for out-of-bounds points.

    Modifies data in-place. Assumes NaN has already been set for OOB points
    by the interpolation functions.

    Parameters
    ----------
    data : ndarray
        Data array with NaN for out-of-bounds points. Modified in-place.

    new_coords_arr : ndarray
        Query coordinates.

    old_coords_arr : ndarray
        Original grid coordinates.

    dim_fill_value : float or tuple
        Fill value(s) to use. If tuple, (``fill_lower``, ``fill_upper``).
    """
    oob = (new_coords_arr < old_coords_arr[0]) | (new_coords_arr > old_coords_arr[-1])
    if not oob.any():
        return

    if isinstance(dim_fill_value, tuple):
        fill_lo, fill_hi = dim_fill_value
    else:
        fill_lo = fill_hi = dim_fill_value

    lo = new_coords_arr < old_coords_arr[0]
    hi = new_coords_arr > old_coords_arr[-1]

    # Apply fill values along the last axis
    data[..., lo] = fill_lo
    data[..., hi] = fill_hi

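For intuition, a standalone numpy sketch of the fill semantics implemented above (toy arrays, for illustration only; not part of the package diff):

import numpy as np

grid = np.array([0.0, 1.0, 2.0])            # plays the role of old_coords_arr
queries = np.array([-0.5, 0.5, 2.5])        # plays the role of new_coords_arr
data = np.array([[np.nan, 0.4, np.nan]])    # NaN already marks OOB points

# With dim_fill_value = (-1.0, 1.0): lower OOB -> -1.0, upper OOB -> 1.0
lo = queries < grid[0]
hi = queries > grid[-1]
data[..., lo] = -1.0
data[..., hi] = 1.0
print(data)  # [[-1.   0.4  1. ]]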
def _decompose_interp_groups(
    interp_specs: list[dict],
) -> list[list[dict]]:
    """
    Decompose interpolation specs into groups based on destination dimensions.

    Groups specs so that dimensions with independent destinations are in
    separate groups, while dimensions sharing a destination are grouped together.
    This enables using scipy.interpn for multi-dimensional interpolation when
    multiple source dimensions map to the same destination dimension.

    Parameters
    ----------
    interp_specs : list of dict
        List of interpolation specs, each with 'dim', 'arr', 'new_dims', etc.

    Returns
    -------
    list of list of dict
        Groups of specs. Each group can be processed together.
    """
    if not interp_specs:
        return []

    groups: list[list[dict]] = []
    current_group: list[dict] = []
    current_dest_dims: set = set()

    for spec in interp_specs:
        spec_dest_dims = set(spec["new_dims"]) if spec["new_dims"] else {spec["dim"]}

        if not current_group:
            # First spec starts a new group
            current_group = [spec]
            current_dest_dims = spec_dest_dims
        elif spec_dest_dims & current_dest_dims:
            # Shares destination with current group - add to it
            current_group.append(spec)
            current_dest_dims |= spec_dest_dims
        else:
            # Independent of current group - yield current and start new
            groups.append(current_group)
            current_group = [spec]
            current_dest_dims = spec_dest_dims

    if current_group:
        groups.append(current_group)

    return groups

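A quick illustration of the grouping behaviour, assuming axsdb 0.0.3 is importable (the helper is private, so this is for understanding only; grouping looks at the "dim" and "new_dims" keys):

from axsdb.interpolation import _decompose_interp_groups

specs = [
    {"dim": "t", "new_dims": ("z",)},
    {"dim": "p", "new_dims": ("z",)},
    {"dim": "x_H2O", "new_dims": ("z",)},
    {"dim": "w", "new_dims": ()},  # empty new_dims: destination is "w" itself
]
groups = _decompose_interp_groups(specs)
print([[s["dim"] for s in g] for g in groups])  # [['t', 'p', 'x_H2O'], ['w']]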
def _check_uniform_bounds(
    group: list[dict],
    bounds_dict: dict[Hashable, Literal["fill", "clamp", "raise"]],
) -> tuple[bool, Literal["fill", "clamp", "raise"] | None]:
    """
    Check if all dimensions in a group have uniform bounds mode.

    Parameters
    ----------
    group : list of dict
        Group of interpolation specs.
    bounds_dict : dict
        Mapping from dimension names to bounds modes.

    Returns
    -------
    is_uniform : bool
        True if all dimensions have the same bounds mode.

    uniform_mode : {"fill", "clamp", "raise"} or None
        The shared bounds mode, or None if not uniform.
    """
    if len(group) == 0:
        return False, None

    # Extract bounds mode for first spec
    first_dim = group[0]["dim"]
    first_mode = bounds_dict.get(first_dim, "fill")

    # Check if all other specs have same bounds mode
    for spec in group[1:]:
        dim = spec["dim"]
        mode = bounds_dict.get(dim, "fill")
        if mode != first_mode:
            return False, None

    return True, first_mode

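The contract in brief, again assuming axsdb 0.0.3 is importable (missing dimensions default to "fill"):

from axsdb.interpolation import _check_uniform_bounds

group = [{"dim": "t"}, {"dim": "p"}]
print(_check_uniform_bounds(group, {"t": "clamp", "p": "clamp"}))  # (True, 'clamp')
print(_check_uniform_bounds(group, {"t": "clamp"}))                # (False, None)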
def _should_use_fast_path(
    group: list[dict],
    dims: list[Hashable],
    bounds_dict: dict[Hashable, Literal["fill", "clamp", "raise"]],
) -> bool:
    """
    Determine if group qualifies for scipy.interpn fast path.

    Parameters
    ----------
    group : list of dict
        Group of interpolation specs.
    dims : list
        Current dimension names in the data.
    bounds_dict : dict
        Mapping from dimension names to bounds modes.

    Returns
    -------
    bool
        True if the group qualifies for fast path processing.
    """
    # Must have multiple specs in group
    if len(group) <= 1:
        return False

    # All specs must have same non-empty destination dims
    if not group[0]["new_dims"]:
        return False

    dest_dims = group[0]["new_dims"]
    if not all(spec["new_dims"] == dest_dims for spec in group):
        return False

    # Destination dims must not already exist (else need pointwise path)
    if set(dest_dims) & set(dims):
        return False

    # No scalar specs allowed
    if any(spec["is_scalar"] for spec in group):
        return False

    # All dimensions must have uniform bounds mode
    is_uniform, _ = _check_uniform_bounds(group, bounds_dict)
    if not is_uniform:
        return False

    return True

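A sketch of the predicate on two minimal spec groups, assuming axsdb 0.0.3 is importable:

from axsdb.interpolation import _should_use_fast_path

group = [
    {"dim": "t", "new_dims": ("z",), "is_scalar": False},
    {"dim": "p", "new_dims": ("z",), "is_scalar": False},
]
# Destination "z" is new and bounds modes are uniform ("fill" by default): fast path.
print(_should_use_fast_path(group, ["t", "p", "w"], {}))  # True
# Destination "z" already exists: the pointwise path is required instead.
print(_should_use_fast_path(group, ["t", "p", "z"], {}))  # False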
def _interp_group_with_interpn(
    data: np.ndarray,
    dims: list[Hashable],
    da: xr.DataArray,
    group: list[dict],
    bounds_mode: Literal["fill", "clamp", "raise"],
    fill_value_dict: dict[Hashable, float | tuple[float, float]],
) -> tuple[np.ndarray, list[Hashable]]:
    """
    Interpolate multiple dimensions at once using scipy.interpn.

    This is used when multiple source dimensions share a common destination
    dimension (e.g., t, p, x_H2O all mapping to z) and have uniform bounds
    mode. Using interpn is much faster than sequential 1D interpolation.

    Parameters
    ----------
    data : np.ndarray
        Current data array.
    dims : list
        Current dimension names corresponding to data axes.
    da : xr.DataArray
        Original DataArray (for coordinate grids).
    group : list of dict
        Group of specs to interpolate together.
    bounds_mode : {"fill", "clamp", "raise"}
        Uniform bounds mode for all dimensions in group.
    fill_value_dict : dict
        Per-dimension fill values (used for "fill" mode).

    Returns
    -------
    data : np.ndarray
        Interpolated data.
    dims : list
        Updated dimension names.
    """
    # Identify source dims and destination dims
    src_dims = [spec["dim"] for spec in group]
    dest_dims = group[0]["new_dims"]  # All specs in group share destination dims

    # Build grid points tuple for interpn (order matters!)
    grid_points = tuple(da.coords[dim].values for dim in src_dims)

    # Build query points array: shape (dest_size, n_src_dims)
    query_arrays = [spec["arr"] for spec in group]
    xi = np.stack(query_arrays, axis=-1)

    # Find axes of source dims in current data layout
    src_axes = [dims.index(dim) for dim in src_dims]
    other_axes = [i for i in range(len(dims)) if i not in src_axes]
    other_dims = [dims[i] for i in other_axes]

    # Transpose: src_dims first, then other dims (batch dims at the end for interpn)
    perm = src_axes + other_axes
    data_transposed = data.transpose(perm)

    # Handle bounds mode uniformly
    if bounds_mode == "raise":
        # Pre-validate all dimensions before interpolation
        for i, (dim, grid) in enumerate(zip(src_dims, grid_points)):
            xi_dim = xi[..., i]
            xi_min, xi_max = xi_dim.min(), xi_dim.max()
            grid_min, grid_max = grid[0], grid[-1]

            if xi_min < grid_min or xi_max > grid_max:
                raise ValueError(
                    f"Interpolation error on dimension {dim!r}: "
                    f"Query points out of bounds.\n"
                    f"  Grid range: [{grid_min}, {grid_max}]\n"
                    f"  Query range: [{xi_min}, {xi_max}]\n"
                    f"  Bounds mode: raise"
                )
        # All checks passed - call interpn with bounds_error=False
        result = interpn(
            grid_points,
            data_transposed,
            xi,
            method="linear",
            bounds_error=False,
            fill_value=np.nan,
        )

    elif bounds_mode == "clamp":
        # Clamp query points to grid bounds
        xi_clamped = xi.copy()
        for i, grid in enumerate(grid_points):
            xi_clamped[..., i] = np.clip(xi_clamped[..., i], grid[0], grid[-1])
        result = interpn(
            grid_points,
            data_transposed,
            xi_clamped,
            method="linear",
            bounds_error=False,
            fill_value=np.nan,
        )

    else:  # bounds_mode == "fill"
        # Call interpn with bounds_error=False to get NaN for OOB
        result = interpn(
            grid_points,
            data_transposed,
            xi,
            method="linear",
            bounds_error=False,
            fill_value=np.nan,
        )

        # Post-process per-dimension fill values
        # For each dimension, replace NaN where that dim was OOB
        for i, (dim, grid) in enumerate(zip(src_dims, grid_points)):
            dim_fill_value = fill_value_dict.get(dim, np.nan)

            # Skip if fill value is already NaN (no-op)
            if isinstance(dim_fill_value, tuple):
                fill_lo, fill_hi = dim_fill_value
            else:
                fill_lo = fill_hi = dim_fill_value

            if not (np.isnan(fill_lo) and np.isnan(fill_hi)):
                # Identify OOB points for this dimension
                xi_dim = xi[..., i]
                grid_min, grid_max = grid[0], grid[-1]

                # Apply fill values
                # Note: result shape is (dest_size, *batch_dims)
                oob_lo = xi_dim < grid_min
                oob_hi = xi_dim > grid_max

                # Broadcast OOB mask to result shape
                # xi_dim shape: (dest_size,), result shape: (dest_size, *batch)
                if oob_lo.any():
                    result[oob_lo, ...] = fill_lo
                if oob_hi.any():
                    result[oob_hi, ...] = fill_hi

    # Update dims: replace src_dims with dest_dims
    new_dims = list(dest_dims) + other_dims

    return result, new_dims

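For intuition, a standalone sketch of the scipy.interpn call pattern this helper builds on (toy data, not part of the package): two source grids (t, p) queried along profiles t(z), p(z) that share the destination dimension z.

import numpy as np
from scipy.interpolate import interpn

t_grid = np.linspace(200.0, 300.0, 5)
p_grid = np.linspace(1e4, 1e5, 4)
values = np.arange(20.0).reshape(5, 4)      # data on the (t, p) grid

t_of_z = np.array([210.0, 250.0, 290.0])
p_of_z = np.array([2e4, 5e4, 9e4])
xi = np.stack([t_of_z, p_of_z], axis=-1)    # shape (3, 2): one (t, p) pair per z

out = interpn((t_grid, p_grid), values, xi, method="linear")
print(out.shape)  # (3,): one value per z level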
def interp_dataarray(
    da: xr.DataArray,
    coords: dict[Hashable, float | np.ndarray | xr.DataArray],
    bounds: Literal["fill", "clamp", "raise"]
    | dict[Hashable, Literal["fill", "clamp", "raise"]] = "fill",
    fill_value: float
    | tuple[float, float]
    | dict[Hashable, float | tuple[float, float]] = np.nan,
) -> xr.DataArray:
    """
    Fast linear interpolation for xarray DataArrays.

    This function provides a high-performance alternative to xarray's
    ``.interp()`` method for linear interpolation, with per-dimension
    control over out-of-bounds handling behaviour.

    Parameters
    ----------
    da : DataArray
        The input DataArray to interpolate.

    coords : dict
        Mapping from dimension names to new coordinate values. Each value
        can be a scalar, numpy array, or xarray DataArray. Dimensions are
        interpolated sequentially in the order given.

    bounds : {"fill", "clamp", "raise"} or dict, default: "fill"
        How to handle out-of-bounds query points. Can be a single value
        applied to all dimensions, or a dict mapping dimension names to
        bounds modes. Missing dimensions default to ``"fill"``.

        * "fill": Use ``fill_value`` for points outside the data range.
        * "clamp": Use the nearest boundary value.
        * "raise": Raise a ValueError if any query point is out of bounds.

    fill_value : float or tuple or dict, default: np.nan
        Value(s) to use for out-of-bounds points when ``bounds="fill"``.
        Can be:

        * a single float (used for both lower and upper bounds, all dims);
        * a 2-tuple (``fill_lower``, ``fill_upper``) for all dimensions;
        * a dict mapping dimension names to floats or 2-tuples.

        Missing dimensions default to ``np.nan``.

    Returns
    -------
    DataArray
        Interpolated DataArray with the new coordinates. Preserves the
        original DataArray's name and attributes.

    Raises
    ------
    ValueError
        If a dimension in ``coords`` does not exist in the DataArray.
        If ``bounds="raise"`` and any query point is out of bounds.
        The error message includes the dimension name for easier debugging.

    Notes
    -----
    The function assumes coordinates are sorted in ascending order along
    each interpolation dimension. Results are undefined if this assumption
    is violated.

    Interpolation is performed sequentially across dimensions in the order
    given in ``coords``. This is equivalent to chained ``.interp()`` calls
    but significantly faster due to the optimized gufunc implementation.

    Examples
    --------
    Basic multi-dimensional interpolation:

    >>> import numpy as np
    >>> import xarray as xr
    >>> from axsdb.interpolation import interp_dataarray
    >>>
    >>> # Create sample data
    >>> da = xr.DataArray(
    ...     np.random.rand(10, 20, 30),
    ...     dims=["wavelength", "angle", "time"],
    ...     coords={
    ...         "wavelength": np.linspace(400, 700, 10),
    ...         "angle": np.linspace(0, 90, 20),
    ...         "time": np.linspace(0, 1, 30),
    ...     },
    ... )
    >>>
    >>> # Interpolate to new coordinates
    >>> result = interp_dataarray(
    ...     da,
    ...     {
    ...         "wavelength": np.array([450.0, 550.0, 650.0]),
    ...         "angle": np.array([30.0, 60.0]),
    ...     },
    ... )
    >>> result.shape
    (3, 2, 30)

    Different bounds handling per dimension:

    >>> result = interp_dataarray(
    ...     da,
    ...     {"wavelength": new_wavelengths, "angle": new_angles},
    ...     bounds={"wavelength": "raise", "angle": "clamp"},
    ... )

    Custom fill values per dimension:

    >>> result = interp_dataarray(
    ...     da,
    ...     {"wavelength": new_wavelengths, "angle": new_angles},
    ...     bounds="fill",
    ...     fill_value={"wavelength": 0.0, "angle": (-1.0, 1.0)},
    ... )
    """
    # Normalize bounds to dict format
    if isinstance(bounds, str):
        bounds_dict: dict[Hashable, Literal["fill", "clamp", "raise"]] = {
            dim: bounds for dim in coords
        }
    else:
        bounds_dict = dict(bounds)
        for dim in coords:
            if dim not in bounds_dict:
                bounds_dict[dim] = "fill"

    # Normalize fill_value to dict format
    if isinstance(fill_value, (int, float)) or (
        isinstance(fill_value, tuple) and len(fill_value) == 2
    ):
        fill_value_dict: dict[Hashable, float | tuple[float, float]] = {
            dim: fill_value for dim in coords
        }
    else:
        fill_value_dict = dict(fill_value)  # type: ignore[arg-type]
        for dim in coords:
            if dim not in fill_value_dict:
                fill_value_dict[dim] = np.nan

    # --- Pre-process all interpolation targets once ---
    # For each dimension we resolve the new coordinates to a plain numpy
    # array and record the metadata needed for the final DataArray wrap.
    interp_specs: list[dict] = []
    for dim, new_coords in coords.items():
        if dim not in da.dims:
            raise ValueError(
                f"Dimension {dim!r} not found in DataArray. "
                f"Available dimensions: {list(da.dims)}"
            )
        if isinstance(new_coords, xr.DataArray):
            spec = {
                "dim": dim,
                "arr": new_coords.values,
                "new_dims": new_coords.dims,
                "new_coords": {
                    k: v for k, v in new_coords.coords.items() if k in new_coords.dims
                },
                "is_scalar": False,
            }
        elif np.isscalar(new_coords):
            spec = {
                "dim": dim,
                "arr": np.array([new_coords]),
                "new_dims": (),
                "new_coords": {},
                "is_scalar": True,
            }
        else:
            arr = np.asarray(new_coords)
            spec = {
                "dim": dim,
                "arr": arr,
                "new_dims": (),
                "new_coords": {},
                "is_scalar": arr.ndim == 0,
            }
            if spec["is_scalar"]:
                spec["arr"] = arr.reshape(1)
        interp_specs.append(spec)

    # --- Reorder: shrinking dimensions before expanding ones ---
    # Interpolating a dimension that expands (query size > grid size) while
    # other spectator dimensions are still large multiplies work
    # unnecessarily. Processing shrinking dims first keeps the intermediate
    # array small until the expansion happens on a much smaller array.
    #
    # Within expanding dimensions, process LARGER grids first. This minimizes
    # the intermediate array size when introducing a new dimension (e.g. z).
    # The first expander creates (total / grid_size) * query_size elements;
    # larger grid_size means smaller intermediate. Subsequent expanders
    # share the new dimension and reduce the array; processing larger grids
    # first means fewer lerp operations overall.
    def _interp_sort_key(spec: dict) -> tuple:
        grid_size = da.sizes[spec["dim"]]
        query_size = len(spec["arr"])
        # (0, ...) = shrink/same -> first; (1, ...) = expand -> last
        # Within each group, larger grid_size first (hence negative)
        if query_size <= grid_size:
            return (0, -grid_size)
        else:
            return (1, -grid_size)

    interp_specs.sort(key=_interp_sort_key)

    # --- Decompose into groups for potential fast path ---
    # Group specs by destination dimensions. When multiple source dims map to
    # the same destination (e.g. t, p, x_H2O all -> z), processing them together
    # with scipy.interpn is much faster than sequential 1D interpolation.
    interp_groups = _decompose_interp_groups(interp_specs)

    # --- Main interpolation loop on raw numpy arrays ---
    # `dims` tracks the current logical dimension order.
    # `data` is the raw ndarray; axes correspond 1-to-1 with `dims`.
    dims: list[Hashable] = list(da.dims)
    data: np.ndarray = da.values

    for group in interp_groups:
        # Check if this group can use the fast interpn path:
        # - Multiple specs in the group
        # - All specs have the same non-empty new_dims (shared destination)
        # - The destination dims don't already exist in the current array
        #   (if they do, we need the pointwise/shared path)
        # - All dimensions have uniform bounds mode
        use_interpn = _should_use_fast_path(group, dims, bounds_dict)

        if use_interpn:
            # Fast path: multi-dimensional interpn
            is_uniform, uniform_mode = _check_uniform_bounds(group, bounds_dict)
            data, dims = _interp_group_with_interpn(
                data, dims, da, group, uniform_mode, fill_value_dict
            )
            continue

        # Sequential path: process each spec one at a time
        for spec in group:
            dim = spec["dim"]
            new_coords_arr: np.ndarray = spec["arr"]
            new_coords_dims: tuple = spec["new_dims"]
            is_scalar: bool = spec["is_scalar"]

            dim_bounds = bounds_dict.get(dim, "fill")
            dim_fill_value = fill_value_dict.get(dim, np.nan)

            dim_axis = dims.index(dim)
            old_coords_arr = da.coords[dim].values

            # Detect shared-dimension case: new_coords has a dim that already
            # exists in the current working set.
            shared_dims = set(new_coords_dims) & set(dims) if new_coords_dims else set()

            if shared_dims:
                # --- Shared-dimension (pointwise) path ---
                # This path handles cases where the new coordinates share a dimension
                # with the existing data (e.g. interpolating t->z when z already exists).
                #
                # Strategy: Precompute indices/weights once for the shared dimension,
                # then apply them pointwise to each slice. This avoids redundant
                # binary searches across the shared dimension.
                #
                # Limitation: Only single shared dimension is currently supported.
                # Multiple shared dims would require more complex bookkeeping.
                if len(shared_dims) > 1:
                    raise NotImplementedError(
                        f"Multiple shared dimensions not supported: {shared_dims}. "
                        f"Interpolating dimension {dim!r} with coordinates that share "
                        f"multiple dimensions with the current data is not implemented."
                    )

                shared_dim = next(iter(shared_dims))
                shared_axis = dims.index(shared_dim)
                shared_size = data.shape[shared_axis]

                # Transpose to (..., shared_dim, dim)
                other_axes = [
                    i for i in range(len(dims)) if i != dim_axis and i != shared_axis
                ]
                perm = other_axes + [shared_axis, dim_axis]
                data = data.transpose(perm)

                # Precompute bin indices and weights once for all shared-dim
                # query points against the (small, uniform) grid. This avoids
                # repeating the binary search for every slice along the leading
                # batch dimensions.
                try:
                    indices, weights = lerp_indices(
                        old_coords_arr, new_coords_arr, bounds=dim_bounds
                    )
                except ValueError as e:
                    x_min, x_max = old_coords_arr[0], old_coords_arr[-1]
                    q_min, q_max = new_coords_arr.min(), new_coords_arr.max()
                    raise ValueError(
                        f"Interpolation error on dimension {dim!r}: {e}\n"
                        f"  Grid range: [{x_min}, {x_max}]\n"
                        f"  Query range: [{q_min}, {q_max}]\n"
                        f"  Bounds mode: {dim_bounds}"
                    ) from e

                # indices/weights are (shared_size,). Reshape to
                # (shared_size, 1) so the gufunc signature (n),(m),(m)->(m)
                # treats the last axis as the core (m=1 query per z-slice)
                # and broadcasts over all leading batch dims.
                indices_bc = indices.reshape((shared_size, 1))
                weights_bc = weights.reshape((shared_size, 1))

                data = lerp(data, indices_bc, weights_bc)

                # Handle fill values for bounds="fill" (lerp propagates NaN
                # for OOB weights; replace with the requested fill value).
                if dim_bounds == "fill":
                    # Note: data layout is (..., shared_size, 1); OOB mask is
                    # over the shared_size axis (second-to-last).
                    # Need special indexing for this layout
                    oob = (new_coords_arr < old_coords_arr[0]) | (
                        new_coords_arr > old_coords_arr[-1]
                    )
                    if oob.any():
                        if isinstance(dim_fill_value, tuple):
                            fill_lo, fill_hi = dim_fill_value
                        else:
                            fill_lo = fill_hi = dim_fill_value
                        lo = new_coords_arr < old_coords_arr[0]
                        hi = new_coords_arr > old_coords_arr[-1]
                        # data layout: (..., shared_size, 1)
                        data[..., lo, :] = fill_lo
                        data[..., hi, :] = fill_hi

                # Squeeze the trailing length-1 query axis
                data = data[..., 0]

                # Update dims: remove dim, keep shared_dim in its original
                # relative position among surviving dims.
                dims = [d for d in dims if d != dim]
                # data is now in layout (other_dims..., shared_dim); transpose
                # back so shared_dim is where it belongs.
                # Build the permutation that undoes the earlier reorder.
                # Current logical order after removal: other_dims + [shared_dim]
                # Target order: dims (which preserved original relative order).
                current_order = [d for d in dims if d != shared_dim] + [shared_dim]
                if current_order != dims:
                    perm = [current_order.index(d) for d in dims]
                    data = data.transpose(perm)

            else:
                # --- Standard path (no shared dims) ---
                # Move dim to last axis for the gufunc
                if dim_axis != len(dims) - 1:
                    perm = list(range(len(dims)))
                    perm.remove(dim_axis)
                    perm.append(dim_axis)
                    data = data.transpose(perm)

                # When new_coords is a uniform 1-D array (same query points for
                # every slice along the leading dims), precompute the bin indices
                # and weights once and use the search-free lerp gufunc. This
                # avoids redundant binary searches across spectator dimensions.
                use_precomputed = new_coords_arr.ndim == 1 and data.ndim > 1

                try:
                    if use_precomputed:
                        indices, weights = lerp_indices(
                            old_coords_arr, new_coords_arr, bounds=dim_bounds
                        )
                        data = lerp(data, indices, weights)
                        # lerp_indices marks OOB weights as NaN so
                        # lerp propagates NaN. Replace with the requested fill
                        # value for bounds="fill".
                        if dim_bounds == "fill":
                            _apply_fill_values(
                                data, new_coords_arr, old_coords_arr, dim_fill_value
                            )
                    else:
                        old_bc = np.broadcast_to(
                            old_coords_arr,
                            data.shape[:-1] + (len(old_coords_arr),),
                        )
                        new_bc = (
                            new_coords_arr.reshape(1)
                            if new_coords_arr.ndim == 0
                            else new_coords_arr
                        )
                        data = interp1d(
                            old_bc,
                            data,
                            new_bc,
                            bounds=dim_bounds,
                            fill_value=dim_fill_value,
                        )
                except ValueError as e:
                    x_min, x_max = old_coords_arr[0], old_coords_arr[-1]
                    q_min = (
                        new_coords_arr.min()
                        if new_coords_arr.size > 0
                        else float("nan")
                    )
                    q_max = (
                        new_coords_arr.max()
                        if new_coords_arr.size > 0
                        else float("nan")
                    )
                    raise ValueError(
                        f"Interpolation error on dimension {dim!r}: {e}\n"
                        f"  Grid range: [{x_min}, {x_max}]\n"
                        f"  Query range: [{q_min}, {q_max}]\n"
                        f"  Bounds mode: {dim_bounds}"
                    ) from e

                if is_scalar:
                    # Scalar query: squeeze the last axis, remove dim entirely
                    data = data[..., 0]
                    # dims without dim, preserving order (dim was moved to end
                    # for the gufunc but we track logical order separately)
                    dims = [d for d in dims if d != dim]
                elif new_coords_dims:
                    # Dimension replaced: swap dim for new_coords_dims at dim_axis.
                    # data layout is (...without dim..., len(new_coords_arr))
                    # which means new dim is at the end. We want it at dim_axis.
                    dims = [d for d in dims if d != dim]
                    # Insert new dims at the original position
                    for i, nd in enumerate(new_coords_dims):
                        dims.insert(dim_axis + i, nd)
                    # Move the last axis (new dim) to dim_axis
                    n = len(dims)
                    perm = list(range(n - 1))
                    perm.insert(dim_axis, n - 1)
                    data = data.transpose(perm)
                else:
                    # Same dim, resized in place. data has dim at the end;
                    # move it back to dim_axis.
                    if dim_axis != len(dims) - 1:
                        n = len(dims)
                        perm = list(range(n - 1))
                        perm.insert(dim_axis, n - 1)
                        data = data.transpose(perm)
                    # dims unchanged (same names, same order)

    # --- Wrap back into a DataArray ---
    # Collect coordinates for the output: keep original coords whose dims
    # are all still present, then add any new coords from interp targets.
    out_coords: dict = {}
    for coord_name, coord_val in da.coords.items():
        if all(d in dims for d in coord_val.dims):
            out_coords[coord_name] = coord_val

    for spec in interp_specs:
        dim = spec["dim"]
        if spec["is_scalar"]:
            continue
        if spec["new_dims"]:
            # Coordinates from the replacement DataArray
            out_coords.update(spec["new_coords"])
        else:
            # Plain array: attach as a coordinate on its own dim
            out_coords[dim] = (dim, spec["arr"])

    return xr.DataArray(
        data, dims=dims, coords=out_coords, name=da.name, attrs=da.attrs
    )
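A minimal end-to-end check of the public entry point, assuming axsdb 0.0.3 is installed: for in-bounds queries the result should match xarray's built-in .interp.

import numpy as np
import xarray as xr
from axsdb.interpolation import interp_dataarray

da = xr.DataArray(
    np.random.rand(8, 16),
    dims=["wavelength", "angle"],
    coords={
        "wavelength": np.linspace(400, 700, 8),
        "angle": np.linspace(0, 90, 16),
    },
)
new = {"wavelength": np.array([450.0, 620.0]), "angle": np.array([10.0, 45.0, 80.0])}

fast = interp_dataarray(da, new)
slow = da.interp(wavelength=new["wavelength"], angle=new["angle"])
np.testing.assert_allclose(fast.values, slow.values)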