axsdb 0.0.2__py3-none-any.whl → 0.0.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
axsdb/interpolation.py ADDED
@@ -0,0 +1,803 @@
"""
Fast xarray interpolation.

This module provides high-performance interpolation functions for xarray
DataArrays that bypass xarray's built-in interpolation. This is motivated
by performance regressions in recent xarray versions (see
https://github.com/pydata/xarray/issues/10683).
"""

from __future__ import annotations

from collections.abc import Hashable
from typing import Literal

import numpy as np
import xarray as xr
from scipy.interpolate import interpn

from .math import interp1d, lerp, lerp_indices


def _apply_fill_values(
    data: np.ndarray,
    new_coords_arr: np.ndarray,
    old_coords_arr: np.ndarray,
    dim_fill_value: float | tuple[float, float],
) -> None:
    """
    Replace NaN values with specified fill values for out-of-bounds points.

    Modifies data in-place. Assumes NaN has already been set for OOB points
    by the interpolation functions.

    Parameters
    ----------
    data : ndarray
        Data array with NaN for out-of-bounds points. Modified in-place.

    new_coords_arr : ndarray
        Query coordinates.

    old_coords_arr : ndarray
        Original grid coordinates.

    dim_fill_value : float or tuple
        Fill value(s) to use. If tuple, (``fill_lower``, ``fill_upper``).
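
    Examples
    --------
    A minimal sketch of the in-place fill, assuming the interpolation step
    already wrote NaN at the out-of-bounds positions:

    >>> data = np.array([np.nan, 1.5, 2.5, np.nan])
    >>> _apply_fill_values(
    ...     data,
    ...     np.array([-1.0, 1.5, 2.5, 4.0]),  # query coordinates
    ...     np.array([0.0, 3.0]),  # grid spans [0, 3]
    ...     (0.0, 10.0),  # (fill_lower, fill_upper)
    ... )
    >>> data.tolist()
    [0.0, 1.5, 2.5, 10.0]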
+ """
48
+ oob = (new_coords_arr < old_coords_arr[0]) | (new_coords_arr > old_coords_arr[-1])
49
+ if not oob.any():
50
+ return
51
+
52
+ if isinstance(dim_fill_value, tuple):
53
+ fill_lo, fill_hi = dim_fill_value
54
+ else:
55
+ fill_lo = fill_hi = dim_fill_value
56
+
57
+ lo = new_coords_arr < old_coords_arr[0]
58
+ hi = new_coords_arr > old_coords_arr[-1]
59
+
60
+ # Apply fill values along the last axis
61
+ data[..., lo] = fill_lo
62
+ data[..., hi] = fill_hi
63
+
64
+
65
+ def _decompose_interp_groups(
66
+ interp_specs: list[dict],
67
+ ) -> list[list[dict]]:
68
+ """
69
+ Decompose interpolation specs into groups based on destination dimensions.
70
+
71
+ Groups specs so that dimensions with independent destinations are in
72
+ separate groups, while dimensions sharing a destination are grouped together.
73
+ This enables using scipy.interpn for multi-dimensional interpolation when
74
+ multiple source dimensions map to the same destination dimension.
75
+
76
+ Parameters
77
+ ----------
78
+ interp_specs : list of dict
79
+ List of interpolation specs, each with 'dim', 'arr', 'new_dims', etc.
80
+
81
+ Returns
82
+ -------
83
+ list of list of dict
84
+ Groups of specs. Each group can be processed together.
85
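
    Examples
    --------
    A minimal sketch, with specs trimmed to the keys this function reads:

    >>> specs = [
    ...     {"dim": "t", "new_dims": ("z",)},
    ...     {"dim": "p", "new_dims": ("z",)},
    ...     {"dim": "w", "new_dims": ()},
    ... ]
    >>> [[s["dim"] for s in g] for g in _decompose_interp_groups(specs)]
    [['t', 'p'], ['w']]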
+ """
86
+ if not interp_specs:
87
+ return []
88
+
89
+ groups: list[list[dict]] = []
90
+ current_group: list[dict] = []
91
+ current_dest_dims: set = set()
92
+
93
+ for spec in interp_specs:
94
+ spec_dest_dims = set(spec["new_dims"]) if spec["new_dims"] else {spec["dim"]}
95
+
96
+ if not current_group:
97
+ # First spec starts a new group
98
+ current_group = [spec]
99
+ current_dest_dims = spec_dest_dims
100
+ elif spec_dest_dims & current_dest_dims:
101
+ # Shares destination with current group - add to it
102
+ current_group.append(spec)
103
+ current_dest_dims |= spec_dest_dims
104
+ else:
105
+ # Independent of current group - yield current and start new
106
+ groups.append(current_group)
107
+ current_group = [spec]
108
+ current_dest_dims = spec_dest_dims
109
+
110
+ if current_group:
111
+ groups.append(current_group)
112
+
113
+ return groups
114
+
115
+
116
+ def _check_uniform_bounds(
117
+ group: list[dict],
118
+ bounds_dict: dict[Hashable, Literal["fill", "clamp", "raise"]],
119
+ ) -> tuple[bool, Literal["fill", "clamp", "raise"] | None]:
120
+ """
121
+ Check if all dimensions in a group have uniform bounds mode.
122
+
123
+ Parameters
124
+ ----------
125
+ group : list of dict
126
+ Group of interpolation specs.
127
+ bounds_dict : dict
128
+ Mapping from dimension names to bounds modes.
129
+
130
+ Returns
131
+ -------
132
+ is_uniform : bool
133
+ True if all dimensions have the same bounds mode.
134
+
135
+ uniform_mode : {"fill", "clamp", "raise"} or None
136
+ The shared bounds mode, or None if not uniform.
137
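
    Examples
    --------
    A minimal sketch; the specs only need a "dim" key here:

    >>> group = [{"dim": "t"}, {"dim": "p"}]
    >>> _check_uniform_bounds(group, {"t": "clamp", "p": "clamp"})
    (True, 'clamp')
    >>> _check_uniform_bounds(group, {"t": "clamp"})  # "p" defaults to "fill"
    (False, None)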
+ """
138
+ if len(group) == 0:
139
+ return False, None
140
+
141
+ # Extract bounds mode for first spec
142
+ first_dim = group[0]["dim"]
143
+ first_mode = bounds_dict.get(first_dim, "fill")
144
+
145
+ # Check if all other specs have same bounds mode
146
+ for spec in group[1:]:
147
+ dim = spec["dim"]
148
+ mode = bounds_dict.get(dim, "fill")
149
+ if mode != first_mode:
150
+ return False, None
151
+
152
+ return True, first_mode
153
+
154
+
155
+ def _should_use_fast_path(
156
+ group: list[dict],
157
+ dims: list[Hashable],
158
+ bounds_dict: dict[Hashable, Literal["fill", "clamp", "raise"]],
159
+ ) -> bool:
160
+ """
161
+ Determine if group qualifies for scipy.interpn fast path.
162
+
163
+ Parameters
164
+ ----------
165
+ group : list of dict
166
+ Group of interpolation specs.
167
+ dims : list
168
+ Current dimension names in the data.
169
+ bounds_dict : dict
170
+ Mapping from dimension names to bounds modes.
171
+
172
+ Returns
173
+ -------
174
+ bool
175
+ True if the group qualifies for fast path processing.
176
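
    Examples
    --------
    A minimal sketch: two specs sharing destination dim "z", which does not
    exist yet in the data dims:

    >>> group = [
    ...     {"dim": "t", "new_dims": ("z",), "is_scalar": False},
    ...     {"dim": "p", "new_dims": ("z",), "is_scalar": False},
    ... ]
    >>> _should_use_fast_path(group, ["t", "p", "w"], {})
    True
    >>> _should_use_fast_path(group[:1], ["t", "p", "w"], {})  # single spec
    False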
+ """
177
+ # Must have multiple specs in group
178
+ if len(group) <= 1:
179
+ return False
180
+
181
+ # All specs must have same non-empty destination dims
182
+ if not group[0]["new_dims"]:
183
+ return False
184
+
185
+ dest_dims = group[0]["new_dims"]
186
+ if not all(spec["new_dims"] == dest_dims for spec in group):
187
+ return False
188
+
189
+ # Destination dims must not already exist (else need pointwise path)
190
+ if set(dest_dims) & set(dims):
191
+ return False
192
+
193
+ # No scalar specs allowed
194
+ if any(spec["is_scalar"] for spec in group):
195
+ return False
196
+
197
+ # All dimensions must have uniform bounds mode
198
+ is_uniform, _ = _check_uniform_bounds(group, bounds_dict)
199
+ if not is_uniform:
200
+ return False
201
+
202
+ return True
203
+
204
+
205
+ def _interp_group_with_interpn(
206
+ data: np.ndarray,
207
+ dims: list[Hashable],
208
+ da: xr.DataArray,
209
+ group: list[dict],
210
+ bounds_mode: Literal["fill", "clamp", "raise"],
211
+ fill_value_dict: dict[Hashable, float | tuple[float, float]],
212
+ ) -> tuple[np.ndarray, list[Hashable]]:
213
+ """
214
+ Interpolate multiple dimensions at once using scipy.interpn.
215
+
216
+ This is used when multiple source dimensions share a common destination
217
+ dimension (e.g., t, p, x_H2O all mapping to z) and have uniform bounds
218
+ mode. Using interpn is much faster than sequential 1D interpolation.
219
+
220
+ Parameters
221
+ ----------
222
+ data : np.ndarray
223
+ Current data array.
224
+ dims : list
225
+ Current dimension names corresponding to data axes.
226
+ da : xr.DataArray
227
+ Original DataArray (for coordinate grids).
228
+ group : list of dict
229
+ Group of specs to interpolate together.
230
+ bounds_mode : {"fill", "clamp", "raise"}
231
+ Uniform bounds mode for all dimensions in group.
232
+ fill_value_dict : dict
233
+ Per-dimension fill values (used for "fill" mode).
234
+
235
+ Returns
236
+ -------
237
+ data : np.ndarray
238
+ Interpolated data.
239
+ dims : list
240
+ Updated dimension names.
241
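
    Examples
    --------
    A minimal sketch with two source dims mapping onto one destination dim
    (names and values are illustrative):

    >>> da = xr.DataArray(
    ...     np.arange(6.0).reshape(2, 3),
    ...     dims=["t", "p"],
    ...     coords={"t": [0.0, 1.0], "p": [0.0, 1.0, 2.0]},
    ... )
    >>> group = [
    ...     {"dim": "t", "arr": np.array([0.0, 0.5]), "new_dims": ("z",)},
    ...     {"dim": "p", "arr": np.array([1.0, 2.0]), "new_dims": ("z",)},
    ... ]
    >>> data, dims = _interp_group_with_interpn(
    ...     da.values, ["t", "p"], da, group, "fill", {}
    ... )
    >>> dims
    ['z']
    >>> data.tolist()
    [1.0, 3.5]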
+ """
242
+ # Identify source dims and destination dims
243
+ src_dims = [spec["dim"] for spec in group]
244
+ dest_dims = group[0]["new_dims"] # All specs in group share destination dims
245
+
246
+ # Build grid points tuple for interpn (order matters!)
247
+ grid_points = tuple(da.coords[dim].values for dim in src_dims)
248
+
249
+ # Build query points array: shape (dest_size, n_src_dims)
250
+ query_arrays = [spec["arr"] for spec in group]
251
+ xi = np.stack(query_arrays, axis=-1)
252
+
253
+ # Find axes of source dims in current data layout
254
+ src_axes = [dims.index(dim) for dim in src_dims]
255
+ other_axes = [i for i in range(len(dims)) if i not in src_axes]
256
+ other_dims = [dims[i] for i in other_axes]
257
+
258
+ # Transpose: src_dims first, then other dims (batch dims at the end for interpn)
259
+ perm = src_axes + other_axes
260
+ data_transposed = data.transpose(perm)
261
+
262
+ # Handle bounds mode uniformly
263
+ if bounds_mode == "raise":
264
+ # Pre-validate all dimensions before interpolation
265
+ for i, (dim, grid) in enumerate(zip(src_dims, grid_points)):
266
+ xi_dim = xi[..., i]
267
+ xi_min, xi_max = xi_dim.min(), xi_dim.max()
268
+ grid_min, grid_max = grid[0], grid[-1]
269
+
270
+ if xi_min < grid_min or xi_max > grid_max:
271
+ raise ValueError(
272
+ f"Interpolation error on dimension {dim!r}: "
273
+ f"Query points out of bounds.\n"
274
+ f" Grid range: [{grid_min}, {grid_max}]\n"
275
+ f" Query range: [{xi_min}, {xi_max}]\n"
276
+ f" Bounds mode: raise"
277
+ )
278
+ # All checks passed - call interpn with bounds_error=False
279
+ result = interpn(
280
+ grid_points,
281
+ data_transposed,
282
+ xi,
283
+ method="linear",
284
+ bounds_error=False,
285
+ fill_value=np.nan,
286
+ )
287
+
288
+ elif bounds_mode == "clamp":
289
+ # Clamp query points to grid bounds
290
+ xi_clamped = xi.copy()
291
+ for i, grid in enumerate(grid_points):
292
+ xi_clamped[..., i] = np.clip(xi_clamped[..., i], grid[0], grid[-1])
293
+ result = interpn(
294
+ grid_points,
295
+ data_transposed,
296
+ xi_clamped,
297
+ method="linear",
298
+ bounds_error=False,
299
+ fill_value=np.nan,
300
+ )
301
+
302
+ else: # bounds_mode == "fill"
303
+ # Call interpn with bounds_error=False to get NaN for OOB
304
+ result = interpn(
305
+ grid_points,
306
+ data_transposed,
307
+ xi,
308
+ method="linear",
309
+ bounds_error=False,
310
+ fill_value=np.nan,
311
+ )
312
+
313
+ # Post-process per-dimension fill values
314
+ # For each dimension, replace NaN where that dim was OOB
315
+ for i, (dim, grid) in enumerate(zip(src_dims, grid_points)):
316
+ dim_fill_value = fill_value_dict.get(dim, np.nan)
317
+
318
+ # Skip if fill value is already NaN (no-op)
319
+ if isinstance(dim_fill_value, tuple):
320
+ fill_lo, fill_hi = dim_fill_value
321
+ else:
322
+ fill_lo = fill_hi = dim_fill_value
323
+
324
+ if not (np.isnan(fill_lo) and np.isnan(fill_hi)):
325
+ # Identify OOB points for this dimension
326
+ xi_dim = xi[..., i]
327
+ grid_min, grid_max = grid[0], grid[-1]
328
+
329
+ # Apply fill values
330
+ # Note: result shape is (dest_size, *batch_dims)
331
+ oob_lo = xi_dim < grid_min
332
+ oob_hi = xi_dim > grid_max
333
+
334
+ # Broadcast OOB mask to result shape
335
+ # xi_dim shape: (dest_size,), result shape: (dest_size, *batch)
336
+ if oob_lo.any():
337
+ result[oob_lo, ...] = fill_lo
338
+ if oob_hi.any():
339
+ result[oob_hi, ...] = fill_hi
340
+
341
+ # Update dims: replace src_dims with dest_dims
342
+ new_dims = list(dest_dims) + other_dims
343
+
344
+ return result, new_dims
345
+
346
+
347
+ def interp_dataarray(
348
+ da: xr.DataArray,
349
+ coords: dict[Hashable, float | np.ndarray | xr.DataArray],
350
+ bounds: Literal["fill", "clamp", "raise"]
351
+ | dict[Hashable, Literal["fill", "clamp", "raise"]] = "fill",
352
+ fill_value: float
353
+ | tuple[float, float]
354
+ | dict[Hashable, float | tuple[float, float]] = np.nan,
355
+ ) -> xr.DataArray:
356
+ """
357
+ Fast linear interpolation for xarray DataArrays.
358
+
359
+ This function provides a high-performance alternative to xarray's
360
+ ``.interp()`` method for linear interpolation, with the possibility to
361
+ control out-of-bound handling behaviour per dimension.
362
+
363
+ Parameters
364
+ ----------
365
+ da : DataArray
366
+ The input DataArray to interpolate.
367
+
368
+ coords : dict
369
+ Mapping from dimension names to new coordinate values. Each value
370
+ can be a scalar, numpy array, or xarray DataArray. Dimensions are
371
+ interpolated sequentially in the order given.
372
+
373
+ bounds : {"fill", "clamp", "raise"} or dict, default: "fill"
374
+ How to handle out-of-bounds query points. Can be a single value
375
+ applied to all dimensions, or a dict mapping dimension names to
376
+ bounds modes. Missing dimensions default to ``"fill"``.
377
+
378
+ * "fill": Use ``fill_value`` for points outside the data range.
379
+ * "clamp": Use the nearest boundary value.
380
+ * "raise": Raise a ValueError if any query point is out of bounds.
381

    fill_value : float or tuple or dict, default: np.nan
        Value(s) to use for out-of-bounds points when ``bounds="fill"``.
        Can be:

        * a single float (used for both lower and upper bounds, all dims);
        * a 2-tuple (``fill_lower``, ``fill_upper``) for all dimensions;
        * a dict mapping dimension names to floats or 2-tuples.

        Missing dimensions default to ``np.nan``.

    Returns
    -------
    DataArray
        Interpolated DataArray with the new coordinates. Preserves the
        original DataArray's name and attributes.

    Raises
    ------
    ValueError
        If a dimension in ``coords`` does not exist in the DataArray, or if
        ``bounds="raise"`` and any query point is out of bounds. The error
        message includes the dimension name for easier debugging.

    Notes
    -----
    The function assumes coordinates are sorted in ascending order along
    each interpolation dimension. Results are undefined if this assumption
    is violated.

    Interpolation is performed sequentially across dimensions. The
    processing order may differ from the order given in ``coords``: it is
    chosen internally for performance (shrinking dimensions are processed
    first). For linear interpolation this is equivalent to chained
    ``.interp()`` calls, but significantly faster due to the optimized
    gufunc implementation.

    Examples
    --------
    Basic multi-dimensional interpolation:

    >>> import numpy as np
    >>> import xarray as xr
    >>> from axsdb.interpolation import interp_dataarray
    >>>
    >>> # Create sample data
    >>> da = xr.DataArray(
    ...     np.random.rand(10, 20, 30),
    ...     dims=["wavelength", "angle", "time"],
    ...     coords={
    ...         "wavelength": np.linspace(400, 700, 10),
    ...         "angle": np.linspace(0, 90, 20),
    ...         "time": np.linspace(0, 1, 30),
    ...     },
    ... )
    >>>
    >>> # Interpolate to new coordinates
    >>> result = interp_dataarray(
    ...     da,
    ...     {
    ...         "wavelength": np.array([450.0, 550.0, 650.0]),
    ...         "angle": np.array([30.0, 60.0]),
    ...     },
    ... )
    >>> result.shape
    (3, 2, 30)

    Different bounds handling per dimension:

    >>> new_wavelengths = np.array([450.0, 550.0, 650.0])
    >>> new_angles = np.array([30.0, 60.0])
    >>> result = interp_dataarray(
    ...     da,
    ...     {"wavelength": new_wavelengths, "angle": new_angles},
    ...     bounds={"wavelength": "raise", "angle": "clamp"},
    ... )

    Custom fill values per dimension:

    >>> result = interp_dataarray(
    ...     da,
    ...     {"wavelength": new_wavelengths, "angle": new_angles},
    ...     bounds="fill",
    ...     fill_value={"wavelength": 0.0, "angle": (-1.0, 1.0)},
    ... )
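
    Scalar queries drop the corresponding dimension:

    >>> interp_dataarray(da, {"time": 0.5}).dims
    ('wavelength', 'angle')

    DataArray-valued coordinates can map several source dimensions onto a
    shared destination dimension (here ``z``), which enables the
    multi-dimensional fast path:

    >>> result = interp_dataarray(
    ...     da,
    ...     {
    ...         "wavelength": xr.DataArray(np.array([450.0, 550.0]), dims="z"),
    ...         "angle": xr.DataArray(np.array([10.0, 20.0]), dims="z"),
    ...     },
    ... )
    >>> result.dims
    ('z', 'time')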
+ """
462
+ # Normalize bounds to dict format
463
+ if isinstance(bounds, str):
464
+ bounds_dict: dict[Hashable, Literal["fill", "clamp", "raise"]] = {
465
+ dim: bounds for dim in coords
466
+ }
467
+ else:
468
+ bounds_dict = dict(bounds)
469
+ for dim in coords:
470
+ if dim not in bounds_dict:
471
+ bounds_dict[dim] = "fill"
472
+
473
+ # Normalize fill_value to dict format
474
+ if isinstance(fill_value, (int, float)) or (
475
+ isinstance(fill_value, tuple) and len(fill_value) == 2
476
+ ):
477
+ fill_value_dict: dict[Hashable, float | tuple[float, float]] = {
478
+ dim: fill_value for dim in coords
479
+ }
480
+ else:
481
+ fill_value_dict = dict(fill_value) # type: ignore[arg-type]
482
+ for dim in coords:
483
+ if dim not in fill_value_dict:
484
+ fill_value_dict[dim] = np.nan
485
+
486
+ # --- Pre-process all interpolation targets once ---
487
+ # For each dimension we resolve the new coordinates to a plain numpy
488
+ # array and record the metadata needed for the final DataArray wrap.
489
+ interp_specs: list[dict] = []
490
+ for dim, new_coords in coords.items():
491
+ if dim not in da.dims:
492
+ raise ValueError(
493
+ f"Dimension {dim!r} not found in DataArray. "
494
+ f"Available dimensions: {list(da.dims)}"
495
+ )
496
+ if isinstance(new_coords, xr.DataArray):
497
+ spec = {
498
+ "dim": dim,
499
+ "arr": new_coords.values,
500
+ "new_dims": new_coords.dims,
501
+ "new_coords": {
502
+ k: v for k, v in new_coords.coords.items() if k in new_coords.dims
503
+ },
504
+ "is_scalar": False,
505
+ }
506
+ elif np.isscalar(new_coords):
507
+ spec = {
508
+ "dim": dim,
509
+ "arr": np.array([new_coords]),
510
+ "new_dims": (),
511
+ "new_coords": {},
512
+ "is_scalar": True,
513
+ }
514
+ else:
515
+ arr = np.asarray(new_coords)
516
+ spec = {
517
+ "dim": dim,
518
+ "arr": arr,
519
+ "new_dims": (),
520
+ "new_coords": {},
521
+ "is_scalar": arr.ndim == 0,
522
+ }
523
+ if spec["is_scalar"]:
524
+ spec["arr"] = arr.reshape(1)
525
+ interp_specs.append(spec)
526
+
527
+ # --- Reorder: shrinking dimensions before expanding ones ---
528
+ # Interpolating a dimension that expands (query size > grid size) while
529
+ # other spectator dimensions are still large multiplies work
530
+ # unnecessarily. Processing shrinking dims first keeps the intermediate
531
+ # array small until the expansion happens on a much smaller array.
532
+ #
533
+ # Within expanding dimensions, process LARGER grids first. This minimizes
534
+ # the intermediate array size when introducing a new dimension (e.g. z).
535
+ # The first expander creates (total / grid_size) * query_size elements;
536
+ # larger grid_size means smaller intermediate. Subsequent expanders
537
+ # share the new dimension and reduce the array; processing larger grids
538
+ # first means fewer lerp operations overall.
539
+ def _interp_sort_key(spec: dict) -> tuple:
540
+ grid_size = da.sizes[spec["dim"]]
541
+ query_size = len(spec["arr"])
542
+ # (0, ...) = shrink/same -> first; (1, ...) = expand -> last
543
+ # Within each group, larger grid_size first (hence negative)
544
+ if query_size <= grid_size:
545
+ return (0, -grid_size)
546
+ else:
547
+ return (1, -grid_size)
548
+
549
+ interp_specs.sort(key=_interp_sort_key)
550
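    # For example (hypothetical sizes): with a grid t of size 8 queried at
    # 4 points (shrink) and a grid z of size 64 queried at 256 points
    # (expand), the sort keys are (0, -8) and (1, -64), so t is processed
    # first and the expansion to 256 points happens on the already-shrunken
    # array.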

    # --- Decompose into groups for potential fast path ---
    # Group specs by destination dimensions. When multiple source dims map to
    # the same destination (e.g. t, p, x_H2O all -> z), processing them together
    # with scipy.interpn is much faster than sequential 1D interpolation.
    interp_groups = _decompose_interp_groups(interp_specs)

    # --- Main interpolation loop on raw numpy arrays ---
    # `dims` tracks the current logical dimension order.
    # `data` is the raw ndarray; axes correspond 1-to-1 with `dims`.
    dims: list[Hashable] = list(da.dims)
    data: np.ndarray = da.values

    for group in interp_groups:
        # Check if this group can use the fast interpn path:
        # - Multiple specs in the group
        # - All specs have the same non-empty new_dims (shared destination)
        # - The destination dims don't already exist in the current array
        #   (if they do, we need the pointwise/shared path)
        # - All dimensions have uniform bounds mode
        use_interpn = _should_use_fast_path(group, dims, bounds_dict)

        if use_interpn:
            # Fast path: multi-dimensional interpn
            is_uniform, uniform_mode = _check_uniform_bounds(group, bounds_dict)
            data, dims = _interp_group_with_interpn(
                data, dims, da, group, uniform_mode, fill_value_dict
            )
            continue

        # Sequential path: process each spec one at a time
        for spec in group:
            dim = spec["dim"]
            new_coords_arr: np.ndarray = spec["arr"]
            new_coords_dims: tuple = spec["new_dims"]
            is_scalar: bool = spec["is_scalar"]

            dim_bounds = bounds_dict.get(dim, "fill")
            dim_fill_value = fill_value_dict.get(dim, np.nan)

            dim_axis = dims.index(dim)
            old_coords_arr = da.coords[dim].values

            # Detect shared-dimension case: new_coords has a dim that already
            # exists in the current working set.
            shared_dims = set(new_coords_dims) & set(dims) if new_coords_dims else set()

            if shared_dims:
                # --- Shared-dimension (pointwise) path ---
                # This path handles cases where the new coordinates share a dimension
                # with the existing data (e.g. interpolating t->z when z already exists).
                #
                # Strategy: Precompute indices/weights once for the shared dimension,
                # then apply them pointwise to each slice. This avoids redundant
                # binary searches across the shared dimension.
                #
                # Limitation: Only a single shared dimension is currently supported.
                # Multiple shared dims would require more complex bookkeeping.
                if len(shared_dims) > 1:
                    raise NotImplementedError(
                        f"Multiple shared dimensions not supported: {shared_dims}. "
                        f"Interpolating dimension {dim!r} with coordinates that share "
                        f"multiple dimensions with the current data is not implemented."
                    )

                shared_dim = next(iter(shared_dims))
                shared_axis = dims.index(shared_dim)
                shared_size = data.shape[shared_axis]

                # Transpose to (..., shared_dim, dim)
                other_axes = [
                    i for i in range(len(dims)) if i != dim_axis and i != shared_axis
                ]
                perm = other_axes + [shared_axis, dim_axis]
                data = data.transpose(perm)

                # Precompute bin indices and weights once for all shared-dim
                # query points against the (small, uniform) grid. This avoids
                # repeating the binary search for every slice along the leading
                # batch dimensions.
                try:
                    indices, weights = lerp_indices(
                        old_coords_arr, new_coords_arr, bounds=dim_bounds
                    )
                except ValueError as e:
                    x_min, x_max = old_coords_arr[0], old_coords_arr[-1]
                    q_min, q_max = new_coords_arr.min(), new_coords_arr.max()
                    raise ValueError(
                        f"Interpolation error on dimension {dim!r}: {e}\n"
                        f"  Grid range: [{x_min}, {x_max}]\n"
                        f"  Query range: [{q_min}, {q_max}]\n"
                        f"  Bounds mode: {dim_bounds}"
                    ) from e

                # indices/weights are (shared_size,). Reshape to
                # (shared_size, 1) so the gufunc signature (n),(m),(m)->(m)
                # treats the last axis as the core (m=1 query per z-slice)
                # and broadcasts over all leading batch dims.
                indices_bc = indices.reshape((shared_size, 1))
                weights_bc = weights.reshape((shared_size, 1))

                data = lerp(data, indices_bc, weights_bc)

                # Handle fill values for bounds="fill" (lerp propagates NaN
                # for OOB weights; replace with the requested fill value).
                if dim_bounds == "fill":
                    # Note: data layout is (..., shared_size, 1); the OOB mask
                    # is over the shared_size axis (second-to-last), so this
                    # layout needs special indexing.
                    oob = (new_coords_arr < old_coords_arr[0]) | (
                        new_coords_arr > old_coords_arr[-1]
                    )
                    if oob.any():
                        if isinstance(dim_fill_value, tuple):
                            fill_lo, fill_hi = dim_fill_value
                        else:
                            fill_lo = fill_hi = dim_fill_value
                        lo = new_coords_arr < old_coords_arr[0]
                        hi = new_coords_arr > old_coords_arr[-1]
                        # data layout: (..., shared_size, 1)
                        data[..., lo, :] = fill_lo
                        data[..., hi, :] = fill_hi

                # Squeeze the trailing length-1 query axis
                data = data[..., 0]

                # Update dims: remove dim, keep shared_dim in its original
                # relative position among surviving dims.
                dims = [d for d in dims if d != dim]
                # data is now in layout (other_dims..., shared_dim); transpose
                # back so shared_dim is where it belongs.
                # Build the permutation that undoes the earlier reorder.
                # Current logical order after removal: other_dims + [shared_dim]
                # Target order: dims (which preserved original relative order).
                current_order = [d for d in dims if d != shared_dim] + [shared_dim]
                if current_order != dims:
                    perm = [current_order.index(d) for d in dims]
                    data = data.transpose(perm)

            else:
                # --- Standard path (no shared dims) ---
                # Move dim to last axis for the gufunc
                if dim_axis != len(dims) - 1:
                    perm = list(range(len(dims)))
                    perm.remove(dim_axis)
                    perm.append(dim_axis)
                    data = data.transpose(perm)

                # When new_coords is a uniform 1-D array (same query points for
                # every slice along the leading dims), precompute the bin indices
                # and weights once and use the search-free lerp gufunc. This
                # avoids redundant binary searches across spectator dimensions.
                use_precomputed = new_coords_arr.ndim == 1 and data.ndim > 1

                try:
                    if use_precomputed:
                        indices, weights = lerp_indices(
                            old_coords_arr, new_coords_arr, bounds=dim_bounds
                        )
                        data = lerp(data, indices, weights)
                        # lerp_indices marks OOB weights as NaN so lerp
                        # propagates NaN. Replace with the requested fill
                        # value for bounds="fill".
                        if dim_bounds == "fill":
                            _apply_fill_values(
                                data, new_coords_arr, old_coords_arr, dim_fill_value
                            )
                    else:
                        old_bc = np.broadcast_to(
                            old_coords_arr,
                            data.shape[:-1] + (len(old_coords_arr),),
                        )
                        new_bc = (
                            new_coords_arr.reshape(1)
                            if new_coords_arr.ndim == 0
                            else new_coords_arr
                        )
                        data = interp1d(
                            old_bc,
                            data,
                            new_bc,
                            bounds=dim_bounds,
                            fill_value=dim_fill_value,
                        )
                except ValueError as e:
                    x_min, x_max = old_coords_arr[0], old_coords_arr[-1]
                    q_min = (
                        new_coords_arr.min()
                        if new_coords_arr.size > 0
                        else float("nan")
                    )
                    q_max = (
                        new_coords_arr.max()
                        if new_coords_arr.size > 0
                        else float("nan")
                    )
                    raise ValueError(
                        f"Interpolation error on dimension {dim!r}: {e}\n"
                        f"  Grid range: [{x_min}, {x_max}]\n"
                        f"  Query range: [{q_min}, {q_max}]\n"
                        f"  Bounds mode: {dim_bounds}"
                    ) from e

                if is_scalar:
                    # Scalar query: squeeze the last axis, remove dim entirely
                    data = data[..., 0]
                    # dims without dim, preserving order (dim was moved to end
                    # for the gufunc but we track logical order separately)
                    dims = [d for d in dims if d != dim]
                elif new_coords_dims:
                    # Dimension replaced: swap dim for new_coords_dims at dim_axis.
                    # data layout is (...without dim..., len(new_coords_arr)),
                    # which means the new dim is at the end. We want it at dim_axis.
                    dims = [d for d in dims if d != dim]
                    # Insert new dims at the original position
                    for i, nd in enumerate(new_coords_dims):
                        dims.insert(dim_axis + i, nd)
                    # Move the last axis (new dim) to dim_axis
                    n = len(dims)
                    perm = list(range(n - 1))
                    perm.insert(dim_axis, n - 1)
                    data = data.transpose(perm)
                else:
                    # Same dim, resized in place. data has dim at the end;
                    # move it back to dim_axis.
                    if dim_axis != len(dims) - 1:
                        n = len(dims)
                        perm = list(range(n - 1))
                        perm.insert(dim_axis, n - 1)
                        data = data.transpose(perm)
                    # dims unchanged (same names, same order)

    # --- Wrap back into a DataArray ---
    # Collect coordinates for the output: keep original coords whose dims
    # are all still present, then add any new coords from interp targets.
    out_coords: dict = {}
    for coord_name, coord_val in da.coords.items():
        if all(d in dims for d in coord_val.dims):
            out_coords[coord_name] = coord_val

    for spec in interp_specs:
        dim = spec["dim"]
        if spec["is_scalar"]:
            continue
        if spec["new_dims"]:
            # Coordinates from the replacement DataArray
            out_coords.update(spec["new_coords"])
        else:
            # Plain array: attach as a coordinate on its own dim
            out_coords[dim] = (dim, spec["arr"])

    return xr.DataArray(
        data, dims=dims, coords=out_coords, name=da.name, attrs=da.attrs
    )