ocstrack 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1 @@
1
+
@@ -0,0 +1,429 @@
1
+ from typing import Optional, Tuple, Union
2
+ import logging
3
+ import numpy as np
4
+ import xarray as xr
5
+ from tqdm import tqdm
6
+
7
+ from Model.model import SCHISM
8
+ from Satellite.satellite import SatelliteData
9
+ from Collocation.temporal import temporal_nearest, temporal_interpolated
10
+ from Collocation.spatial import GeocentricSpatialLocator, inverse_distance_weights
11
+ from Collocation.output import make_collocated_nc
12
+
13
# Module-level logging setup so collocation progress and warnings reach the
# console by default.
# NOTE(review): calling basicConfig at import time configures the whole
# application's root logger — confirm this is intended for library use.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(levelname)s - %(message)s",
    handlers=[logging.StreamHandler()],
)
_logger = logging.getLogger(__name__)
19
+
20
+
21
class Collocate:
    """Model–satellite collocation engine.

    This is the main class. It handles the spatial and temporal collocation
    of satellite altimetry data (e.g., significant wave height, sea level
    anomaly (TBD)) with unstructured model outputs (e.g., SCHISM). It
    supports both nearest-neighbor (in time) and temporally interpolated
    collocation strategies.

    Methods
    -------
    run(output_path=None)
        Run the collocation over all model files and return a combined
        xarray.Dataset.

    Notes
    -----
    Collocation is performed using:
    - Nearest N spatial nodes (with inverse distance weighting)
    - Radius (meters) based search
    - Nearest or interpolated temporal matching
    - Optional distance-to-coast dataset for filtering/post-processing

    Automatically infers time_buffer from model time step if not provided.
    """
    def __init__(self,
                 model_run: SCHISM,
                 satellite: SatelliteData,
                 dist_coast: Optional[xr.Dataset] = None,
                 n_nearest: Optional[int] = None,
                 search_radius: Optional[float] = None,
                 time_buffer: Optional[np.timedelta64] = None,
                 weight_power: float = 1.0,
                 temporal_interp: bool = False) -> None:
        """
        Parameters
        ----------
        model_run : SCHISM
            Model object containing grid, file paths, and data access
        satellite : SatelliteData
            Satellite data wrapper providing SWH, SLA, etc.
        dist_coast : xarray.Dataset, optional
            Optional dataset containing distance-to-coast info
        n_nearest : int, optional
            Number of nearest spatial model nodes to use
        search_radius : float, optional
            Radius (in meters) to search for spatial neighbors.
            If provided, overrides n_nearest and uses radius-based spatial matching.
        time_buffer : np.timedelta64, optional
            Temporal search buffer; if None, inferred from model timestep
        weight_power : float, default=1.0
            Power exponent for inverse distance weighting
        temporal_interp : bool, default=False
            Whether to perform linear temporal interpolation

        Raises
        ------
        ValueError
            If neither 'n_nearest' nor 'search_radius' is provided, or if
            time_buffer must be inferred but the first model file has fewer
            than two timesteps.
        """
        self.model = model_run
        self.sat = satellite
        # Keep only the distance variable; the rest of the dataset is unused.
        self.dist_coast = dist_coast["distcoast"] if dist_coast is not None else None
        self.n_nearest = n_nearest
        self.search_radius = search_radius
        self.weight_power = weight_power
        self.temporal_interp = temporal_interp

        # Validate the spatial-matching configuration up front.
        if search_radius is not None and n_nearest is not None:
            _logger.warning("Both search_radius and n_nearest provided; ignoring n_nearest and using radius-based spatial matching.")
        elif search_radius is None and n_nearest is None:
            raise ValueError("Specify either 'n_nearest' or 'search_radius'")

        # Set locator
        _logger.info("Initializing 3D Geocentric (WGS 84) spatial locator.")
        self.locator = GeocentricSpatialLocator(
            self.model.mesh_x, self.model.mesh_y, model_height=None
        )

        # If radius search is on, nullify n_nearest
        if search_radius is not None:
            self.n_nearest = None  # Prevent accidental use

        # Automatically estimate time buffer if not provided
        if time_buffer is None:
            example_file = self.model.files[0]
            times = self.model.load_variable(example_file)["time"].values

            if len(times) < 2:
                raise ValueError("Cannot infer time_buffer: less than two model timesteps.")

            # Calculate timestep and use half of it as buffer
            timestep = times[1] - times[0]  # Assumes constant step
            self.time_buffer = timestep / 2
            _logger.info(f"Inferred time_buffer as half timestep: {self.time_buffer}")
        else:
            self.time_buffer = time_buffer
113
+
114
+ def _extract_model_values(self,
115
+ m_var: xr.DataArray,
116
+ times_or_inds: Union[np.ndarray, Tuple[np.ndarray, np.ndarray, np.ndarray]],
117
+ nodes: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
118
+ """
119
+ Extract model variable values and corresponding depths at given times and nodes.
120
+
121
+ Parameters
122
+ ----------
123
+ m_var : xarray.DataArray
124
+ Model variable to extract from (e.g. significant wave height)
125
+ times_or_inds : tuple or list
126
+ Time indices or interpolation args (ib, ia, wts)
127
+ nodes : np.ndarray
128
+ Node indices of nearest spatial neighbors
129
+
130
+ Returns
131
+ -------
132
+ Tuple[np.ndarray, np.ndarray]
133
+ Extracted model values and node depths
134
+ """
135
+ model_data = m_var.values
136
+ depths = self.model.mesh_depth
137
+
138
+ values, dpts = [], []
139
+
140
+ if self.temporal_interp:
141
+ ib, ia, wts = times_or_inds
142
+ for i, nd in enumerate(nodes):
143
+ v0 = model_data[ib[i], nd]
144
+ v1 = model_data[ia[i], nd]
145
+ values.append(v0 * (1 - wts[i]) + v1 * wts[i])
146
+ dpts.append(depths[nd])
147
+ else:
148
+ for i, (t_idx, nd) in enumerate(zip(times_or_inds, nodes)):
149
+ t = m_var["time"].values[t_idx]
150
+ values.append(m_var.sel(time=t, nSCHISM_hgrid_node=nd).values)
151
+ dpts.append(depths[nd])
152
+
153
+ return np.array(values), np.array(dpts)
154
+
155
+ def _coast_distance(self,
156
+ lats: np.ndarray,
157
+ lons: np.ndarray) -> np.ndarray:
158
+ """
159
+ Get distance to coast for given lat/lon points using optional dataset.
160
+
161
+ Parameters
162
+ ----------
163
+ lats : array-like
164
+ Latitudes of satellite observations
165
+ lons : array-like
166
+ Longitudes of satellite observations
167
+
168
+ Returns
169
+ -------
170
+ np.ndarray
171
+ Interpolated coastal distances, or NaNs if unavailable
172
+ """
173
+ if self.dist_coast is None:
174
+ return np.full_like(lats, fill_value=np.nan, dtype=float)
175
+ return self.dist_coast.sel(
176
+ latitude=xr.DataArray(lats, dims="points"),
177
+ longitude=xr.DataArray(lons, dims="points"),
178
+ method="nearest",
179
+ ).values
180
+
181
+ def _get_sat_height(self, sat_sub: xr.Dataset) -> np.ndarray:
182
+ """
183
+ Extracts satellite height/altitude from the dataset.
184
+ Defaults to 0m with a warning if not found.
185
+ """
186
+ if 'height' in sat_sub:
187
+ return sat_sub["height"].values
188
+ if 'altitude' in sat_sub:
189
+ return sat_sub["altitude"].values
190
+
191
+ _logger.warning("No 'height' or 'altitude' in satellite data. "
192
+ "Defaulting to 0m for geocentric query. "
193
+ "This may be inaccurate for altimeter data.")
194
+ return np.zeros_like(sat_sub["lon"].values)
195
+
196
+ def _collocate_with_radius(self, sat_sub, m_var, time_args):
197
+ """
198
+ Collocate satellite observations with model output using a spatial search radius.
199
+ This is more challendi
200
+
201
+ Parameters
202
+ ----------
203
+ - sat_sub (xarray.Dataset): Subset of satellite data.
204
+ - m_var (str): Model variable name (e.g., 'sigWaveHeight').
205
+ - time_args (tuple or list): Time interpolation arguments or time indices.
206
+
207
+ Returns
208
+ -------
209
+ - dict: A dictionary containing collocated model variables:
210
+ * model_swh: 2D array [obs, nearest_nodes]
211
+ * model_dpt: 2D array [obs, nearest_nodes]
212
+ * dist_deltas: 2D array [obs, nearest_nodes] (distances)
213
+ * node_ids: 2D array [obs, nearest_nodes]
214
+ * model_swh_weighted: 1D array of weighted model SWH [obs]
215
+ * bias_raw: 1D array of unweighted biases [obs]
216
+ * bias_weighted: 1D array of weighted biases [obs]
217
+
218
+ Notes
219
+ -----
220
+ Padding is applied to all per-observation arrays to ensure they can be stacked into
221
+ uniform 2D arrays, even though the number of nearest model nodes may differ per observation.
222
+ This ensures consistent array dimensions and enables construction of an xarray.Dataset later
223
+ dimension mismatches.
224
+ """
225
+ lons = sat_sub["lon"].values
226
+ lats = sat_sub["lat"].values
227
+ heights = self._get_sat_height(sat_sub)
228
+
229
+ all_dists, all_nodes = self.locator.query_radius(
230
+ lons, lats, heights, radius_m=self.search_radius
231
+ )
232
+
233
+ flat_nodes = []
234
+ flat_ib, flat_ia, flat_wt = [], [], []
235
+ obs_lens = []
236
+
237
+ for i, (nodes, dists) in enumerate(zip(all_nodes, all_dists)):
238
+ obs_lens.append(len(nodes))
239
+ if len(nodes) == 0:
240
+ continue # no nodes found — handled after extraction
241
+
242
+ if self.temporal_interp:
243
+ ib, ia, wts = time_args
244
+ flat_ib.extend([ib[i]] * len(nodes))
245
+ flat_ia.extend([ia[i]] * len(nodes))
246
+ flat_wt.extend([wts[i]] * len(nodes))
247
+ else:
248
+ flat_ib.extend([time_args[i]] * len(nodes)) # just time index
249
+
250
+ flat_nodes.extend(nodes)
251
+
252
+ # Handle case where no nodes were found for any obs
253
+ if not flat_nodes:
254
+ n_obs = len(lons)
255
+ nan_arr = np.full((n_obs, 1), np.nan)
256
+ return {
257
+ "model_swh": nan_arr,
258
+ "model_dpt": nan_arr,
259
+ "dist_deltas": nan_arr,
260
+ "node_ids": nan_arr,
261
+ "model_swh_weighted": np.full(n_obs, np.nan),
262
+ "bias_raw": np.full(n_obs, np.nan),
263
+ "bias_weighted": np.full(n_obs, np.nan),
264
+ }
265
+
266
+ # Perform extraction once
267
+ if self.temporal_interp:
268
+ m_vals, m_dpts = self._extract_model_values(
269
+ m_var, (np.array(flat_ib), np.array(flat_ia), np.array(flat_wt)), np.array(flat_nodes)
270
+ )
271
+ else:
272
+ m_vals, m_dpts = self._extract_model_values(
273
+ m_var, np.array(flat_ib), np.array(flat_nodes)
274
+ )
275
+
276
+ # Reshape into per-observation lists
277
+ def unflatten(arr, lens):
278
+ return np.split(arr, np.cumsum(lens)[:-1])
279
+
280
+ split_vals = unflatten(m_vals, obs_lens)
281
+ split_dpts = unflatten(m_dpts, obs_lens)
282
+ split_dists = unflatten(np.concatenate([np.array(d) for d in all_dists if len(d) > 0]), obs_lens)
283
+ split_nodes = unflatten(np.array(flat_nodes), obs_lens)
284
+
285
+ # Handle obs with no neighbors
286
+ def pad(arrs):
287
+ max_len = max((len(a) for a in arrs), default=1)
288
+ return np.stack([
289
+ np.pad(a, (0, max_len - len(a)), constant_values=np.nan) for a in arrs
290
+ ])
291
+
292
+ # Generate weights and weighted values
293
+ weights_list = [inverse_distance_weights(d[None, :], self.weight_power)[0]
294
+ if len(d) > 0 else np.array([np.nan])
295
+ for d in split_dists]
296
+
297
+ weighted_vals = [np.sum(v * w) if len(v) > 0 else np.nan
298
+ for v, w in zip(split_vals, weights_list)]
299
+
300
+ return {
301
+ "model_swh": pad(split_vals),
302
+ "model_dpt": pad(split_dpts),
303
+ "dist_deltas": pad(split_dists),
304
+ "node_ids": pad([a.astype(float) for a in split_nodes]),
305
+ "model_swh_weighted": np.array(weighted_vals),
306
+ "bias_raw": np.array([
307
+ np.nanmean(v) - s if len(v) > 0 else np.nan
308
+ for v, s in zip(split_vals, sat_sub["swh"].values)
309
+ ]),
310
+ "bias_weighted": np.array(weighted_vals) - sat_sub["swh"].values,
311
+ }
312
+
313
+ def _collocate_with_nearest(self, sat_sub, m_var, time_args):
314
+ """
315
+ Perform collocation using nearest-neighbor spatial search.
316
+
317
+ For each satellite observation, find a fixed number of nearest model nodes,
318
+ extract model values at relevant times (interpolated or nearest),
319
+ compute inverse-distance weights, and calculate weighted averages.
320
+
321
+ Parameters
322
+ ----------
323
+ sat_sub : xarray.Dataset
324
+ Subset of satellite observations to collocate.
325
+ m_var : xarray.DataArray
326
+ Model variable data for the current time slice.
327
+ time_args : tuple or np.ndarray
328
+ Temporal indices or interpolation arguments depending on temporal method.
329
+
330
+ Returns
331
+ -------
332
+ dict
333
+ Dictionary containing arrays for:
334
+ - model_swh: model values per neighbor and observation
335
+ - model_dpt: node depths
336
+ - dist_deltas: distances to neighbors
337
+ - node_ids: spatial node indices
338
+ - model_swh_weighted: weighted model values per observation
339
+ - bias_raw: difference between mean model and satellite values
340
+ - bias_weighted: difference between weighted model and satellite values
341
+ """
342
+ lons = sat_sub["lon"].values
343
+ lats = sat_sub["lat"].values
344
+ heights = self._get_sat_height(sat_sub)
345
+
346
+ dists, nodes = self.locator.query_nearest(
347
+ lons, lats, heights, k=self.n_nearest
348
+ )
349
+
350
+ m_vals, m_dpts = self._extract_model_values(m_var, time_args, nodes)
351
+ weights = inverse_distance_weights(dists, self.weight_power)
352
+ weighted = (m_vals * weights).sum(axis=1)
353
+
354
+ return {
355
+ "model_swh": m_vals,
356
+ "model_dpt": m_dpts,
357
+ "dist_deltas": dists,
358
+ "node_ids": nodes,
359
+ "model_swh_weighted": weighted,
360
+ "bias_raw": m_vals.mean(axis=1) - sat_sub["swh"].values,
361
+ "bias_weighted": weighted - sat_sub["swh"].values,
362
+ }
363
+
364
+ def run(self,
365
+ output_path: Optional[str] = None) -> xr.Dataset:
366
+ """
367
+ Run the full model–satellite collocation process over all model files.
368
+
369
+ This function iterates over all model output files, performs temporal and spatial
370
+ collocation of satellite data with model results, calculates weighted averages,
371
+ biases, and optionally writes the collocated results to a NetCDF file.
372
+
373
+ Parameters
374
+ ----------
375
+ output_path : str, optional
376
+ If provided, writes collocated output to NetCDF file
377
+
378
+ Returns
379
+ -------
380
+ xarray.Dataset
381
+ Dataset containing collocated satellite and model data
382
+ """
383
+ results = {k: [] for k in [
384
+ "time_sat", "lat_sat", "lon_sat", "source_sat",
385
+ "sat_swh", "sat_sla", "model_swh", "model_dpt",
386
+ "dist_deltas", "node_ids", "time_deltas",
387
+ "model_swh_weighted", "bias_raw", "bias_weighted"
388
+ ]}
389
+
390
+ include_coast = self.dist_coast is not None
391
+ if include_coast:
392
+ results["dist_coast"] = []
393
+
394
+ for path in tqdm(self.model.files, desc="Collocating..."):
395
+ m_var = self.model.load_variable(path)
396
+ m_times = m_var["time"].values
397
+
398
+ if self.temporal_interp:
399
+ sat_sub, ib, ia, wts, tdel = temporal_interpolated(self.sat.ds, m_times, self.time_buffer)
400
+ time_args = (ib, ia, wts)
401
+ else:
402
+ sat_sub, idx, tdel = temporal_nearest(self.sat.ds, m_times, self.time_buffer)
403
+ time_args = idx
404
+
405
+ if self.search_radius is not None:
406
+ spatial = self._collocate_with_radius(sat_sub, m_var, time_args)
407
+ else:
408
+ spatial = self._collocate_with_nearest(sat_sub, m_var, time_args)
409
+
410
+ results["time_sat"].append(sat_sub["time"].values)
411
+ results["lat_sat"].append(sat_sub["lat"].values)
412
+ results["lon_sat"].append(sat_sub["lon"].values)
413
+ results["source_sat"].append(sat_sub["source"].values)
414
+ results["sat_swh"].append(sat_sub["swh"].values)
415
+ results["sat_sla"].append(sat_sub["sla"].values)
416
+ results["time_deltas"].append(tdel)
417
+
418
+ for k in ["model_swh", "model_dpt", "dist_deltas", "node_ids", "model_swh_weighted", "bias_raw", "bias_weighted"]:
419
+ results[k].append(spatial[k])
420
+
421
+ if include_coast:
422
+ coast_d = self._coast_distance(sat_sub["lat"].values, sat_sub["lon"].values)
423
+ results["dist_coast"].append(coast_d)
424
+
425
+ n_neighbors = None if self.search_radius is not None else self.n_nearest
426
+ ds_out = make_collocated_nc(results, n_neighbors)
427
+ if output_path:
428
+ ds_out.to_netcdf(output_path)
429
+ return ds_out
@@ -0,0 +1,63 @@
1
from typing import Optional

import numpy as np
import xarray as xr
3
+
4
+
5
def get_max_neighbors(result_list):
    """Return the maximum neighbor (column) count across a list of 2D arrays.

    Parameters
    ----------
    result_list : list of np.ndarray
        Per-file result arrays; only 2D arrays of shape (obs, neighbors)
        are considered.

    Returns
    -------
    int
        Largest second-dimension size found, or 1 when the list contains no
        2D arrays (so downstream padding still has a valid column count
        instead of raising ValueError on an empty collocation run).
    """
    return max((arr.shape[1] for arr in result_list if arr.ndim == 2), default=1)
8
+
9
+
10
def pad_arrays_to_max(arrays, max_cols):
    """Right-pad each 2D array with NaN columns up to ``max_cols``.

    Arrays already at or beyond ``max_cols`` columns are truncated, so every
    returned array has exactly ``max_cols`` columns.
    """
    uniform = []
    for arr in arrays:
        missing = max_cols - arr.shape[1]
        if missing > 0:
            uniform.append(
                np.pad(arr, ((0, 0), (0, missing)), constant_values=np.nan)
            )
        else:
            # Already wide enough — truncate to the target width.
            uniform.append(arr[:, :max_cols])
    return uniform
21
+
22
+
23
def make_collocated_nc(results: dict, n_nearest: Optional[int] = None) -> xr.Dataset:
    """Assemble accumulated collocation results into a single xarray.Dataset.

    Parameters
    ----------
    results : dict
        Mapping of per-file result lists as built by ``Collocate.run``:
        1D satellite columns (lon_sat, lat_sat, sat_swh, ...) and 2D
        neighbor-dependent arrays (model_swh, model_dpt, dist_deltas,
        node_ids). May optionally include a "dist_coast" key.
    n_nearest : int, optional
        Fixed neighbor count (nearest-k mode). If None (radius mode), the
        neighbor dimension is inferred from the widest 2D result array.

    Returns
    -------
    xarray.Dataset
        CF-styled dataset indexed by observation time with a
        "nearest_nodes" neighbor dimension.
    """
    # Determine max neighbors from actual data if not explicitly provided
    max_neighbors = get_max_neighbors(results["model_swh"]) if n_nearest is None else n_nearest

    # Pad all neighbor-dependent arrays so they stack into uniform 2D blocks
    model_swh = pad_arrays_to_max(results["model_swh"], max_neighbors)
    model_dpt = pad_arrays_to_max(results["model_dpt"], max_neighbors)
    dist_deltas = pad_arrays_to_max(results["dist_deltas"], max_neighbors)
    node_ids = pad_arrays_to_max(results["node_ids"], max_neighbors)

    data_vars = {
        "lon": (["time"], np.concatenate(results["lon_sat"])),
        "lat": (["time"], np.concatenate(results["lat_sat"])),
        "sat_swh": (["time"], np.concatenate(results["sat_swh"])),
        "sat_sla": (["time"], np.concatenate(results["sat_sla"])),
        "model_swh": (["time", "nearest_nodes"], np.vstack(model_swh)),
        "model_swh_weighted": (["time"], np.concatenate(results["model_swh_weighted"])),
        "model_dpt": (["time", "nearest_nodes"], np.vstack(model_dpt)),
        "dist_deltas": (["time", "nearest_nodes"], np.vstack(dist_deltas)),
        "node_ids": (["time", "nearest_nodes"], np.vstack(node_ids)),
        "time_deltas": (["time"], np.concatenate(results["time_deltas"])),
        "bias_raw": (["time"], np.concatenate(results["bias_raw"])),
        "bias_weighted": (["time"], np.concatenate(results["bias_weighted"])),
        "source_sat": (["time"], np.concatenate(results["source_sat"])),
    }

    # Coastal distance is only present when a dist_coast dataset was supplied.
    if "dist_coast" in results:
        data_vars["dist_coast"] = (["time"], np.concatenate(results["dist_coast"]))

    ds = xr.Dataset(
        data_vars=data_vars,
        coords={
            "time": np.concatenate(results["time_sat"]),
            "nearest_nodes": np.arange(max_neighbors),
        },
        attrs={
            "Conventions": "CF-1.7",
            "title": "CF-compliant Satellite vs Model SWH Dataset",
        }
    )
    return ds
@@ -0,0 +1,175 @@
1
+ import numpy as np
2
+ from scipy.spatial import KDTree
3
+ from typing import List, Tuple
4
+
5
+
6
def lat_lon_to_cartesian_vec(latitude: np.ndarray,
                             longitude: np.ndarray,
                             height: np.ndarray) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """
    Convert geodetic coordinates to geocentric Cartesian (ECEF) coordinates.

    Uses the WGS 84 ellipsoid model. Latitude and longitude are expected in
    degrees, height in meters above the ellipsoid; the returned (X, Y, Z)
    are in meters.
    """
    # WGS 84 ellipsoid parameters
    semi_major = 6378137.0             # Semi-major axis (meters)
    flattening = 1 / 298.257223563     # Flattening
    ecc_sq = 2 * flattening - flattening**2   # Eccentricity squared (e^2)

    phi = np.radians(latitude)
    lam = np.radians(longitude)
    sin_phi = np.sin(phi)
    cos_phi = np.cos(phi)

    # Prime-vertical radius of curvature at each latitude
    prime_vertical = semi_major / np.sqrt(1 - ecc_sq * sin_phi**2)

    # Project onto the Cartesian axes
    x = (prime_vertical + height) * cos_phi * np.cos(lam)
    y = (prime_vertical + height) * cos_phi * np.sin(lam)
    z = (prime_vertical * (1 - ecc_sq) + height) * sin_phi

    return x, y, z
33
+
34
def inverse_distance_weights(distances: np.ndarray,
                             power: float = 1.0) -> np.ndarray:
    """
    Compute normalized inverse distance weights (IDW).

    Parameters
    ----------
    distances : np.ndarray
        Distance array to nearest neighbors, shape (N, k)
    power : float, optional
        Power exponent for distance weighting (default is 1.0).
        Use 1.0 for linear, 2.0 for quadratic, etc.

    Returns
    -------
    np.ndarray
        Normalized inverse distance weights of shape (N, k); each row
        sums to 1.

    Notes
    -----
    Distances are clamped to at least 1e-6 to avoid division by zero.
    """
    clamped = np.maximum(distances, 1e-6)
    raw_weights = 1.0 / np.power(clamped, power)
    row_totals = raw_weights.sum(axis=1, keepdims=True)
    return raw_weights / row_totals
59
+
60
class GeocentricSpatialLocator:
    """
    KDTree-based spatial query engine using 3D Geocentric (WGS 84) coordinates.

    Handles both nearest-neighbor and radius-based lookups between satellite
    points and model grid nodes using a fast 3D KDTree built on
    ECEF (Earth-Centered, Earth-Fixed) coordinates.
    """

    @staticmethod
    def _to_ecef(lat_deg: np.ndarray,
                 lon_deg: np.ndarray,
                 height_m: np.ndarray) -> np.ndarray:
        """Convert geodetic (deg, deg, m) points to stacked WGS 84 ECEF XYZ, shape (N, 3)."""
        a = 6378137.0               # WGS 84 semi-major axis (meters)
        f = 1 / 298.257223563       # WGS 84 flattening
        e_sq = 2 * f - f**2         # eccentricity squared

        lat = np.radians(lat_deg)
        lon = np.radians(lon_deg)
        # Prime-vertical radius of curvature
        n = a / np.sqrt(1 - e_sq * np.sin(lat)**2)

        x = (n + height_m) * np.cos(lat) * np.cos(lon)
        y = (n + height_m) * np.cos(lat) * np.sin(lon)
        z = (n * (1 - e_sq) + height_m) * np.sin(lat)
        return np.column_stack((x, y, z))

    def __init__(self,
                 model_lon: np.ndarray,
                 model_lat: np.ndarray,
                 model_height: np.ndarray = None) -> None:
        """
        Parameters
        ----------
        model_lon : np.ndarray
            Longitudes of model mesh nodes (degrees)
        model_lat : np.ndarray
            Latitudes of model mesh nodes (degrees)
        model_height : np.ndarray, optional
            Heights of model mesh nodes above ellipsoid (meters).
            If None, defaults to 0 (on the ellipsoid surface).
        """
        if model_height is None:
            model_height = np.zeros_like(model_lon)

        self.model_xyz = self._to_ecef(model_lat, model_lon, model_height)
        self.tree = KDTree(self.model_xyz)

    def _get_query_points(self,
                          sat_lon: np.ndarray,
                          sat_lat: np.ndarray,
                          sat_height: np.ndarray) -> np.ndarray:
        """Helper to convert satellite coordinates to 3D Cartesian."""
        return self._to_ecef(sat_lat, sat_lon, sat_height)

    def query_nearest(self,
                      sat_lon: np.ndarray,
                      sat_lat: np.ndarray,
                      sat_height: np.ndarray,
                      k: int = 3) -> Tuple[np.ndarray, np.ndarray]:
        """
        Query for the 'k' nearest model nodes to each satellite point.

        Parameters
        ----------
        sat_lon : np.ndarray
            Longitudes of satellite observations (degrees)
        sat_lat : np.ndarray
            Latitudes of satellite observations (degrees)
        sat_height : np.ndarray
            Heights/Altitudes of satellite observations (meters above ellipsoid)
        k : int, optional
            Number of nearest model neighbors (default is 3)

        Returns
        -------
        tuple of np.ndarray
            Distances (meters) and indices of nearest model nodes, shape (N, k)
        """
        query_xyz = self._get_query_points(sat_lon, sat_lat, sat_height)
        return self.tree.query(query_xyz, k=k)

    def query_radius(self,
                     sat_lon: np.ndarray,
                     sat_lat: np.ndarray,
                     sat_height: np.ndarray,
                     radius_m: float) -> Tuple[List[np.ndarray], List[np.ndarray]]:
        """
        Query for all model nodes within the 3D search radius.

        Parameters
        ----------
        sat_lon : np.ndarray
            Longitudes of satellite observations (degrees)
        sat_lat : np.ndarray
            Latitudes of satellite observations (degrees)
        sat_height : np.ndarray
            Heights/Altitudes of satellite observations (meters above ellipsoid)
        radius_m : float
            Search radius in meters

        Returns
        -------
        tuple of list[np.ndarray], list[np.ndarray]
            - Distances (in meters) to all matched model nodes per point
            - Corresponding indices of model nodes per point
        """
        query_xyz = self._get_query_points(sat_lon, sat_lat, sat_height)

        # Candidate node indices within the radius, one (possibly empty)
        # list per query point.
        neighbor_lists = self.tree.query_ball_point(query_xyz, r=radius_m)

        distances_out: List[np.ndarray] = []
        indices_out: List[np.ndarray] = []

        for point, node_inds in zip(query_xyz, neighbor_lists):
            if not node_inds:
                # Keep types consistent: empty arrays for empty matches.
                distances_out.append(np.array([]))
                indices_out.append(np.array([]))
                continue

            idx_arr = np.array(node_inds)
            # True 3D Euclidean distances to each matched node.
            offsets = self.model_xyz[idx_arr] - point
            distances_out.append(np.linalg.norm(offsets, axis=1))
            indices_out.append(idx_arr)

        return distances_out, indices_out