eo_tides-0.5.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
eo_tides/utils.py ADDED
@@ -0,0 +1,705 @@
+ # Used to postpone evaluation of type annotations
+ from __future__ import annotations
+
+ import datetime
+ import os
+ import pathlib
+ import textwrap
+ import warnings
+ from collections import Counter
+ from typing import List, Union
+
+ import numpy as np
+ import odc.geo
+ import pandas as pd
+ import xarray as xr
+ from colorama import Style, init
+ from odc.geo.geom import BoundingBox
+ from pyTMD.io.model import load_database
+ from pyTMD.io.model import model as pytmd_model
+ from scipy.spatial import cKDTree as KDTree
+ from tqdm import tqdm
+
+ # Type alias for all possible inputs to "time" params
+ DatetimeLike = Union[np.ndarray, pd.DatetimeIndex, pd.Timestamp, datetime.datetime, str, List[str]]
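+ # Illustrative examples of accepted `time` values (not exhaustive):
+ # "2024-01-01", pd.Timestamp("2024-01-01"), datetime.datetime(2024, 1, 1),
+ # a list of date strings, or a numpy array of datetime64 values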
+
+
+ def _get_duplicates(array):
+     """
+     Return any duplicates in a list or array.
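+
+     Example (illustrative):
+     >>> _get_duplicates(["a", "b", "a"])
+     ['a']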
+     """
+     c = Counter(array)
+     return [k for k in c if c[k] > 1]
+
+
+ def _set_directory(
+     directory: str | os.PathLike | None = None,
+ ) -> os.PathLike:
+     """
+     Set the tide modelling files directory. If no custom
+     path is provided, try the global `EO_TIDES_TIDE_MODELS`
+     environment variable instead.
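+
+     Example (illustrative, assuming the directory exists):
+     >>> os.environ["EO_TIDES_TIDE_MODELS"] = "/var/tide_models"
+     >>> _set_directory()
+     PosixPath('/var/tide_models')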
+     """
+     if directory is None:
+         if "EO_TIDES_TIDE_MODELS" in os.environ:
+             directory = os.environ["EO_TIDES_TIDE_MODELS"]
+         else:
+             raise Exception(
+                 "No tide model directory provided via `directory`, and no "
+                 "`EO_TIDES_TIDE_MODELS` environment variable found. "
+                 "Please provide a valid path to your tide model directory."
+             )
+
+     # Verify path exists
+     directory = pathlib.Path(directory).expanduser()
+     if not directory.exists():
+         raise FileNotFoundError(f"No valid tide model directory found at path `{directory}`")
+     else:
+         return directory
+
+
+ def _standardise_time(
+     time: DatetimeLike | None,
+ ) -> np.ndarray | None:
+     """
+     Accept any time format accepted by `pd.to_datetime`,
+     and return a datetime64 ndarray. Return None if None
+     is passed.
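+
+     Example (illustrative):
+     >>> _standardise_time("2024-01-01")
+     array(['2024-01-01T00:00:00.000000000'], dtype='datetime64[ns]')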
+     """
+     # Return time as-is if None
+     if time is None:
+         return None
+
+     # Use pd.to_datetime for conversion, then convert to numpy array
+     time = pd.to_datetime(time).to_numpy().astype("datetime64[ns]")
+
+     # Ensure that data has at least one dimension
+     return np.atleast_1d(time)
+
+
+ def _standardise_models(
+     model: str | list[str],
+     directory: str | os.PathLike,
+     ensemble_models: list[str] | None = None,
+ ) -> tuple[list[str], list[str], list[str] | None]:
+     """
+     Take an input model name or list of names, and return a list
+     of models to process, requested models, and ensemble models,
+     as required by the `model_tides` function.
+
+     Handles two special values passed to `model`: "all", which
+     will model tides for all models available in `directory`, and
+     "ensemble", which will model tides for all models in a list
+     of custom ensemble models.
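+
+     Example (illustrative, assuming EOT20 files exist in `directory`):
+     >>> _standardise_models("EOT20", directory="tide_models/")
+     (['EOT20'], ['EOT20'], None)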
+     """
+
+     # Turn inputs into arrays for consistent handling
+     models_requested = list(np.atleast_1d(model))
+
+     # Raise error if list contains duplicates
+     duplicates = _get_duplicates(models_requested)
+     if len(duplicates) > 0:
+         raise ValueError(f"The model parameter contains duplicate values: {duplicates}")
+
+     # Get lists of available and supported models from the pyTMD database
+     available_models, valid_models = list_models(
+         directory, show_available=False, show_supported=False, raise_error=True
+     )
+     custom_options = ["ensemble", "all"]
+
+     # Error if any models are not supported
+     if not all(m in valid_models + custom_options for m in models_requested):
+         error_text = (
+             f"One or more of the requested models are not valid:\n"
+             f"{models_requested}\n\n"
+             "The following models are supported:\n"
+             f"{valid_models}"
+         )
+         raise ValueError(error_text)
+
+     # Error if any models are not available in `directory`
+     if not all(m in available_models + custom_options for m in models_requested):
+         error_text = (
+             f"One or more of the requested models are valid, but not available in `{directory}`:\n"
+             f"{models_requested}\n\n"
+             f"The following models are available in `{directory}`:\n"
+             f"{available_models}"
+         )
+         raise ValueError(error_text)
+
+     # If "all" models are requested, update requested list to include available models
+     if "all" in models_requested:
+         models_requested = available_models + [m for m in models_requested if m != "all"]
+
+     # If "ensemble" modelling is requested, use custom list of ensemble models
+     if "ensemble" in models_requested:
+         print("Running ensemble tide modelling")
+         ensemble_models = (
+             ensemble_models
+             if ensemble_models is not None
+             else [
+                 "EOT20",
+                 "FES2012",
+                 "FES2014_extrapolated",
+                 "FES2022_extrapolated",
+                 "GOT4.10",
+                 "GOT5.6_extrapolated",
+                 "TPXO10-atlas-v2-nc",
+                 "TPXO8-atlas-nc",
+                 "TPXO9-atlas-v5-nc",
+             ]
+         )
+
+         # Error if any ensemble models are not available in `directory`
+         if not all(m in available_models for m in ensemble_models):
+             error_text = (
+                 f"One or more of the requested ensemble models are not available in `{directory}`:\n"
+                 f"{ensemble_models}\n\n"
+                 f"The following models are available in `{directory}`:\n"
+                 f"{available_models}"
+             )
+             raise ValueError(error_text)
+
+         # Return set of all ensemble models plus any other requested models
+         models_to_process = sorted(list(set(ensemble_models + [m for m in models_requested if m != "ensemble"])))
+
+     # Otherwise, models to process are the same as those requested
+     else:
+         models_to_process = models_requested
+
+     return models_to_process, models_requested, ensemble_models
+
+
+ def _clip_model_file(
+     nc: xr.Dataset,
+     bbox: BoundingBox,
+     ydim: str,
+     xdim: str,
+     ycoord: str,
+     xcoord: str,
+ ) -> xr.Dataset:
+     """
+     Clip a tide model netCDF dataset to a bounding box.
+
+     If the bounding box crosses 0 degrees longitude (e.g. Greenwich),
+     the function will clip the dataset into two parts and concatenate
+     them along the x-dimension to create a continuous result.
+
+     Parameters
+     ----------
+     nc : xr.Dataset
+         Input tide model xarray dataset.
+     bbox : odc.geo.geom.BoundingBox
+         A BoundingBox object for clipping the dataset in EPSG:4326
+         degrees coordinates. For example:
+         `BoundingBox(left=108, bottom=-48, right=158, top=-6, crs='EPSG:4326')`
+     ydim : str
+         The name of the xarray dimension representing the y-axis.
+         Depending on the tide model, this may or may not contain
+         actual latitude values.
+     xdim : str
+         The name of the xarray dimension representing the x-axis.
+         Depending on the tide model, this may or may not contain
+         actual longitude values.
+     ycoord : str
+         The name of the coordinate, variable or dimension containing
+         actual latitude values used for clipping the data.
+     xcoord : str
+         The name of the coordinate, variable or dimension containing
+         actual longitude values used for clipping the data.
+
+     Returns
+     -------
+     xr.Dataset
+         A dataset clipped to the specified bounding box, with
+         appropriate adjustments if the bounding box crosses 0
+         degrees longitude.
+
+     Examples
+     --------
+     >>> nc = xr.open_dataset("GOT5.5/ocean_tides/2n2.nc")
+     >>> bbox = BoundingBox(left=108, bottom=-48, right=158, top=-6, crs='EPSG:4326')
+     >>> clipped_nc = _clip_model_file(nc, bbox, xdim="lon", ydim="lat", ycoord="latitude", xcoord="longitude")
+     """
+
+     # Extract x and y coords from xarray and load into memory
+     xcoords = nc[xcoord].compute()
+     ycoords = nc[ycoord].compute()
+
+     # If data falls within 0-360 degree bounds, then clip directly
+     if (bbox.left >= 0) & (bbox.right <= 360):
+         nc_clipped = nc.sel({
+             ydim: (ycoords >= bbox.bottom) & (ycoords <= bbox.top),
+             xdim: (xcoords >= bbox.left) & (xcoords <= bbox.right),
+         })
+
+     # If bbox crosses zero longitude, extract left and right
+     # halves separately and then combine into one concatenated dataset
+     elif (bbox.left < 0) & (bbox.right > 0):
+         # Convert longitudes to 0-360 range
+         left = bbox.left % 360
+         right = bbox.right % 360
+
+         # Extract data from left of 0 longitude, and convert lon
+         # coords to -180 to 0 range to enable continuous interpolation
+         # across 0 boundary
+         nc_left = nc.sel({
+             ydim: (ycoords >= bbox.bottom) & (ycoords <= bbox.top),
+             xdim: (xcoords >= left) & (xcoords <= 360),
+         }).assign({xcoord: lambda x: x[xcoord] - 360})
+
+         # Convert additional lon variables for TPXO
+         if "lon_v" in nc_left:
+             nc_left = nc_left.assign({
+                 "lon_v": lambda x: x["lon_v"] - 360,
+                 "lon_u": lambda x: x["lon_u"] - 360,
+             })
+
+         # Extract data to right of 0 longitude
+         nc_right = nc.sel({
+             ydim: (ycoords >= bbox.bottom) & (ycoords <= bbox.top),
+             xdim: (xcoords > 0) & (xcoords <= right),
+         })
+
+         # Combine left and right data along x dimension
+         nc_clipped = xr.concat([nc_left, nc_right], dim=xdim)
+
+         # Hack fix to remove the expanded x dim on lat variables for
+         # TPXO data; remove the x dim by selecting the first observation
+         for i in ["lat_z", "lat_v", "lat_u", "con"]:
+             try:
+                 nc_clipped[i] = nc_clipped[i].isel(nx=0)
+             except (KeyError, ValueError):
+                 pass
+
+     return nc_clipped
+
+
+ def clip_models(
+     input_directory: str | os.PathLike,
+     output_directory: str | os.PathLike,
+     bbox: tuple[float, float, float, float],
+     model: str | list[str] | None = None,
+     buffer: float = 5,
+     overwrite: bool = False,
+ ):
+     """
+     Clip NetCDF-format ocean tide models to a bounding box.
+
+     This function identifies all NetCDF-format tide models in a
+     given input directory, including "ATLAS-netcdf" (e.g. TPXO9-atlas-nc),
+     "FES-netcdf" (e.g. FES2022, EOT20), and "GOT-netcdf" (e.g. GOT5.5)
+     format files. Files for each model are then clipped to the extent of
+     the provided bounding box, handling model-specific file structures.
+     After each model is clipped, the result is exported to the output
+     directory and verified with `pyTMD` to ensure the clipped data is
+     suitable for tide modelling.
+
+     For instructions on accessing and downloading tide models, see:
+     <https://geoscienceaustralia.github.io/eo-tides/setup/>
+
+     Parameters
+     ----------
+     input_directory : str or os.PathLike
+         Path to directory containing input NetCDF-format tide model files.
+     output_directory : str or os.PathLike
+         Path to directory where clipped NetCDF files will be exported.
+     bbox : tuple of float
+         Bounding box for clipping the tide models in EPSG:4326 degrees
+         coordinates, specified as `(left, bottom, right, top)`.
+     model : str or list of str, optional
+         The tide model (or models) to clip. Defaults to None, which
+         will automatically identify and clip all NetCDF-format models
+         in the input directory.
+     buffer : float, optional
+         Buffer distance (in degrees) added to the bounding box to provide
+         sufficient data on the edges of the study area. Defaults to 5 degrees.
+     overwrite : bool, optional
+         If True, overwrite existing files in the output directory.
+         Defaults to False.
+
+     Examples
+     --------
+     >>> clip_models(
+     ...     input_directory="tide_models/",
+     ...     output_directory="tide_models_clipped/",
+     ...     bbox=(-8.968392, 50.070574, 2.447160, 59.367122),
+     ... )
+     """
+
+     # Get input and output paths
+     input_directory = _set_directory(input_directory)
+     output_directory = pathlib.Path(output_directory)
+
+     # Prepare bounding box
+     bbox = odc.geo.geom.BoundingBox(*bbox, crs="EPSG:4326").buffered(buffer)
+
+     # Identify NetCDF models
+     model_database = load_database()["elevation"]
+     netcdf_formats = ["ATLAS-netcdf", "FES-netcdf", "GOT-netcdf"]
+     netcdf_models = {k for k, v in model_database.items() if v["format"] in netcdf_formats}
+
+     # Identify subset of available and requested NetCDF models
+     available_models, _ = list_models(directory=input_directory, show_available=False, show_supported=False)
+     requested_models = list(np.atleast_1d(model)) if model is not None else available_models
+     available_netcdf_models = list(set(available_models) & set(requested_models) & set(netcdf_models))
+
+     # Raise error if no valid models found
+     if len(available_netcdf_models) == 0:
+         raise ValueError(f"No valid NetCDF models found in {input_directory}.")
+
+     # Print list of models to be clipped
+     print(f"Preparing to clip suitable NetCDF models: {available_netcdf_models}\n")
+
+     # Loop through suitable models and export
+     for m in available_netcdf_models:
+         # Get model file and grid file list if they exist
+         model_files = model_database[m].get("model_file", [])
+         grid_file = model_database[m].get("grid_file", [])
+
+         # Convert to list if strings and combine
+         model_files = model_files if isinstance(model_files, list) else [model_files]
+         grid_file = grid_file if isinstance(grid_file, list) else [grid_file]
+         all_files = model_files + grid_file
+
+         # Loop through each model file and clip
+         for file in tqdm(all_files, desc=f"Clipping {m}"):
+             # Skip if the file already exists in the output directory
+             if (output_directory / file).exists() and not overwrite:
+                 continue
+
+             # Load model file
+             nc = xr.open_mfdataset(input_directory / file)
+
+             # Open file and clip according to model
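+             # (dimension and coordinate names differ between model families:
+             # GOT files use "latitude"/"longitude" coords, FES-style files
+             # use "lat"/"lon", HAMTIDE11 uses upper-case "LAT"/"LON", and
+             # TPXO files use "ny"/"nx" grid dims with "lat_z"/"lon_z" coords)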
+             if m in (
+                 "GOT5.5",
+                 "GOT5.5_load",
+                 "GOT5.5_extrapolated",
+                 "GOT5.5D",
+                 "GOT5.5D_extrapolated",
+                 "GOT5.6",
+                 "GOT5.6_extrapolated",
+             ):
+                 nc_clipped = _clip_model_file(
+                     nc,
+                     bbox,
+                     xdim="lon",
+                     ydim="lat",
+                     ycoord="latitude",
+                     xcoord="longitude",
+                 )
+
+             elif m in ("HAMTIDE11",):
+                 nc_clipped = _clip_model_file(nc, bbox, xdim="LON", ydim="LAT", ycoord="LAT", xcoord="LON")
+
+             elif m in (
+                 "EOT20",
+                 "EOT20_load",
+                 "FES2012",
+                 "FES2014",
+                 "FES2014_extrapolated",
+                 "FES2014_load",
+                 "FES2022",
+                 "FES2022_extrapolated",
+                 "FES2022_load",
+             ):
+                 nc_clipped = _clip_model_file(nc, bbox, xdim="lon", ydim="lat", ycoord="lat", xcoord="lon")
+
+             elif m in (
+                 "TPXO8-atlas-nc",
+                 "TPXO9-atlas-nc",
+                 "TPXO9-atlas-v2-nc",
+                 "TPXO9-atlas-v3-nc",
+                 "TPXO9-atlas-v4-nc",
+                 "TPXO9-atlas-v5-nc",
+                 "TPXO10-atlas-v2-nc",
+             ):
+                 nc_clipped = _clip_model_file(
+                     nc,
+                     bbox,
+                     xdim="nx",
+                     ydim="ny",
+                     ycoord="lat_z",
+                     xcoord="lon_z",
+                 )
+
+             else:
+                 raise Exception(f"Model {m} not supported")
+
+             # Create directory and export
+             (output_directory / file).parent.mkdir(parents=True, exist_ok=True)
+             nc_clipped.to_netcdf(output_directory / file, mode="w")
+
+         # Verify that models are ready
+         pytmd_model(directory=output_directory).elevation(m=m).verify
+         print(" ✅ Clipped model exported and verified")
+
+     print(f"\nOutputs exported to {output_directory}")
+     list_models(directory=output_directory, show_available=True, show_supported=False)
+
+
+ def list_models(
+     directory: str | os.PathLike | None = None,
+     show_available: bool = True,
+     show_supported: bool = True,
+     raise_error: bool = False,
+ ) -> tuple[list[str], list[str]]:
+     """
+     List all tide models available for tide modelling.
+
+     This function scans the specified tide model directory
+     and returns a list of models that are available in the
+     directory, as well as the full list of all models supported
+     by `eo-tides` and `pyTMD`.
+
+     For instructions on setting up tide models, see:
+     <https://geoscienceaustralia.github.io/eo-tides/setup/>
+
+     Parameters
+     ----------
+     directory : str or os.PathLike, optional
+         The directory containing tide model data files. If no path is
+         provided, this will default to the environment variable
+         `EO_TIDES_TIDE_MODELS` if set, or raise an error if not.
+         Tide modelling files should be stored in sub-folders for each
+         model that match the structure required by `pyTMD`
+         (<https://geoscienceaustralia.github.io/eo-tides/setup/>).
+     show_available : bool, optional
+         Whether to print a list of locally available models.
+     show_supported : bool, optional
+         Whether to print a list of all supported models, in
+         addition to models available locally.
+     raise_error : bool, optional
+         If True, raise an error if no available models are found.
+         If False, raise a warning.
+
+     Returns
+     -------
+     available_models : list of str
+         A list of all tide models available within `directory`.
+     supported_models : list of str
+         A list of all tide models supported by `eo-tides`.
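+
+     Examples
+     --------
+     >>> # Scan a local tide model directory (illustrative path)
+     >>> available_models, supported_models = list_models(directory="tide_models/")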
+     """
+     init()  # Initialize colorama
+
+     # Set tide modelling files directory. If no custom path is
+     # provided, try global environment variable.
+     directory = _set_directory(directory)
+
+     # Get full list of supported models from pyTMD database
+     model_database = load_database()["elevation"]
+     supported_models = list(model_database.keys())
+
+     # Extract expected model paths
+     expected_paths = {}
+     for m in supported_models:
+         model_file = model_database[m]["model_file"]
+
+         # Handle GOT5.6 differently to ensure we test for presence of GOT5.6 constituents
+         if m in ("GOT5.6", "GOT5.6_extrapolated"):
+             model_file = [file for file in model_file if "GOT5.6" in file][0]
+         else:
+             model_file = model_file[0] if isinstance(model_file, list) else model_file
+
+         # Add path to dict
+         expected_paths[m] = str(directory / pathlib.Path(model_file).expanduser().parent)
+
+     # Define column widths
+     status_width = 4  # Width for emoji
+     name_width = max(len(name) for name in supported_models)
+     path_width = max(len(path) for path in expected_paths.values())
+
+     # Print list of supported models, marking available and
+     # unavailable models and appending available to list
+     if show_available or show_supported:
+         total_width = min(status_width + name_width + path_width + 6, 80)
+         print("─" * total_width)
+         print(f"{'🌊':^{status_width}} | {'Model':<{name_width}} | {'Expected path':<{path_width}}")
+         print("─" * total_width)
+
+     available_models = []
+     for m in supported_models:
+         try:
+             model_file = pytmd_model(directory=directory).elevation(m=m)
+             available_models.append(m)
+
+             if show_available:
+                 # Mark available models with a green tick
+                 status = "✅"
+                 print(f"{status:^{status_width}}│ {m:<{name_width}} │ {expected_paths[m]:<{path_width}}")
+         except FileNotFoundError:
+             if show_supported:
+                 # Mark unavailable models with a red cross
+                 status = "❌"
+                 print(
+                     f"{status:^{status_width}}│ {Style.DIM}{m:<{name_width}} │ {expected_paths[m]:<{path_width}}{Style.RESET_ALL}"
+                 )
+
+     if show_available or show_supported:
+         print("─" * total_width)
+
+         # Print summary
+         print(f"\n{Style.BRIGHT}Summary:{Style.RESET_ALL}")
+         print(f"Available models: {len(available_models)}/{len(supported_models)}")
+
+     # Raise error or warning if no models are available
+     if not available_models:
+         warning_msg = textwrap.dedent(
+             f"""
+             No valid tide models are available in `{directory}`.
+             Are you sure you have provided the correct `directory` path, or set the
+             `EO_TIDES_TIDE_MODELS` environment variable to point to the location of your
+             tide model directory?
+             """
+         ).strip()
+
+         if raise_error:
+             raise Exception(warning_msg)
+         else:
+             warnings.warn(warning_msg, UserWarning)
+
+     # Return list of available and supported models
+     return available_models, supported_models
+
+
+ def idw(
+     input_z,
+     input_x,
+     input_y,
+     output_x,
+     output_y,
+     p=1,
+     k=10,
+     max_dist=None,
+     k_min=1,
+     epsilon=1e-12,
+ ):
+     """Perform Inverse Distance Weighting (IDW) interpolation.
+
+     This function performs fast IDW interpolation by creating a KDTree
+     from the input coordinates, then using it to find the `k` nearest
+     neighbours for each output point. Weights are calculated based on
+     the inverse distance to each neighbour, with weights decreasing
+     with increasing distance.
+
+     Code inspired by: <https://github.com/DahnJ/REM-xarray>
+
+     Parameters
+     ----------
+     input_z : array-like
+         Array of values at the input points. This can be either a
+         1-dimensional array, or a 2-dimensional array where each column
+         (axis=1) represents a different set of values to be interpolated.
+     input_x : array-like
+         Array of x-coordinates of the input points.
+     input_y : array-like
+         Array of y-coordinates of the input points.
+     output_x : array-like
+         Array of x-coordinates where the interpolation is to be computed.
+     output_y : array-like
+         Array of y-coordinates where the interpolation is to be computed.
+     p : int or float, optional
+         Power function parameter defining how rapidly weightings should
+         decrease as distance increases. Higher values of `p` will cause
+         weights for distant points to decrease rapidly, resulting in
+         nearby points having more influence on predictions. Defaults to 1.
+     k : int, optional
+         Number of nearest neighbours to use for interpolation. `k=1` is
+         equivalent to "nearest" neighbour interpolation. Defaults to 10.
+     max_dist : int or float, optional
+         Restrict neighbouring points to less than this distance.
+         By default, no distance limit is applied.
+     k_min : int, optional
+         If `max_dist` is provided, some points may end up with fewer than
+         `k` nearest neighbours, potentially producing less reliable
+         interpolations. Set `k_min` to set any points with fewer than
+         `k_min` neighbours to NaN. Defaults to 1.
+     epsilon : float, optional
+         Small value added to distances to prevent division by zero
+         errors in the case that output coordinates are identical to
+         input coordinates. Defaults to 1e-12.
+
+     Returns
+     -------
+     interp_values : numpy.ndarray
+         Interpolated values at the output coordinates. If `input_z` is
+         1-dimensional, `interp_values` will also be 1-dimensional. If
+         `input_z` is 2-dimensional, `interp_values` will contain one row
+         per output point, with each column (axis=1) representing
+         interpolated values for one set of input data.
+
+     Examples
+     --------
+     >>> input_z = [1, 2, 3, 4, 5]
+     >>> input_x = [0, 1, 2, 3, 4]
+     >>> input_y = [0, 1, 2, 3, 4]
+     >>> output_x = [0.5, 1.5, 2.5]
+     >>> output_y = [0.5, 1.5, 2.5]
+     >>> idw(input_z, input_x, input_y, output_x, output_y, k=2)
+     array([1.5, 2.5, 3.5])
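+
+     Two-dimensional `input_z` interpolates each column independently
+     (illustrative):
+
+     >>> input_z_2d = np.column_stack([[1, 2, 3, 4, 5], [10, 20, 30, 40, 50]])
+     >>> idw(input_z_2d, input_x, input_y, output_x, output_y, k=2)
+     array([[ 1.5, 15. ],
+            [ 2.5, 25. ],
+            [ 3.5, 35. ]])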
+
+     """
+     # Convert to numpy arrays
+     input_x = np.atleast_1d(input_x)
+     input_y = np.atleast_1d(input_y)
+     input_z = np.atleast_1d(input_z)
+     output_x = np.atleast_1d(output_x)
+     output_y = np.atleast_1d(output_y)
+
+     # Verify input and outputs have matching lengths
+     if not (input_z.shape[0] == len(input_x) == len(input_y)):
+         raise ValueError("All of `input_z`, `input_x` and `input_y` must be the same length.")
+     if not (len(output_x) == len(output_y)):
+         raise ValueError("Both `output_x` and `output_y` must be the same length.")
+
+     # Verify k is smaller than total number of points, and non-zero
+     if k > input_z.shape[0]:
+         raise ValueError(
+             f"The requested number of nearest neighbours (`k={k}`) "
+             f"is larger than the total number of points ({input_z.shape[0]}).",
+         )
+     if k == 0:
+         raise ValueError("Interpolation based on `k=0` nearest neighbours is not valid.")
+
+     # Create KDTree to efficiently find nearest neighbours
+     points_xy = np.column_stack((input_y, input_x))
+     tree = KDTree(points_xy)
+
+     # Determine nearest neighbours and distances to each
+     grid_stacked = np.column_stack((output_y, output_x))
+     distances, indices = tree.query(grid_stacked, k=k, workers=-1)
+
+     # If k == 1, add an additional axis for consistency
+     if k == 1:
+         distances = distances[..., np.newaxis]
+         indices = indices[..., np.newaxis]
+
+     # Add small epsilon to distances to prevent division by zero errors
+     # if output coordinates are the same as input coordinates
+     distances = np.maximum(distances, epsilon)
+
+     # Set distances above max to NaN if specified
+     if max_dist is not None:
+         distances[distances > max_dist] = np.nan
+
+     # Calculate weights based on distance to k nearest neighbours
+     weights = 1 / np.power(distances, p)
+     weights = weights / np.nansum(weights, axis=1).reshape(-1, 1)
+
+     # 1D case: Compute weighted sum of input_z values for each output point
+     if input_z.ndim == 1:
+         interp_values = np.nansum(weights * input_z[indices], axis=1)
+
+     # 2D case: Compute weighted sum for each set of input_z values;
+     # weights[..., np.newaxis] adds a dimension for broadcasting
+     else:
+         interp_values = np.nansum(
+             weights[..., np.newaxis] * input_z[indices],
+             axis=1,
+         )
+
+     # Set any points with fewer than `k_min` valid weights to NaN
+     interp_values[np.isfinite(weights).sum(axis=1) < k_min] = np.nan
+
+     return interp_values