nrl-tracker 0.22.5__py3-none-any.whl → 1.8.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (86) hide show
  1. {nrl_tracker-0.22.5.dist-info → nrl_tracker-1.8.0.dist-info}/METADATA +57 -10
  2. {nrl_tracker-0.22.5.dist-info → nrl_tracker-1.8.0.dist-info}/RECORD +86 -69
  3. pytcl/__init__.py +4 -3
  4. pytcl/assignment_algorithms/__init__.py +28 -0
  5. pytcl/assignment_algorithms/dijkstra_min_cost.py +184 -0
  6. pytcl/assignment_algorithms/gating.py +10 -10
  7. pytcl/assignment_algorithms/jpda.py +40 -40
  8. pytcl/assignment_algorithms/nd_assignment.py +379 -0
  9. pytcl/assignment_algorithms/network_flow.py +464 -0
  10. pytcl/assignment_algorithms/network_simplex.py +167 -0
  11. pytcl/assignment_algorithms/three_dimensional/assignment.py +3 -3
  12. pytcl/astronomical/__init__.py +104 -3
  13. pytcl/astronomical/ephemerides.py +14 -11
  14. pytcl/astronomical/reference_frames.py +865 -56
  15. pytcl/astronomical/relativity.py +6 -5
  16. pytcl/astronomical/sgp4.py +710 -0
  17. pytcl/astronomical/special_orbits.py +532 -0
  18. pytcl/astronomical/tle.py +558 -0
  19. pytcl/atmosphere/__init__.py +43 -1
  20. pytcl/atmosphere/ionosphere.py +512 -0
  21. pytcl/atmosphere/nrlmsise00.py +809 -0
  22. pytcl/clustering/dbscan.py +2 -2
  23. pytcl/clustering/gaussian_mixture.py +3 -3
  24. pytcl/clustering/hierarchical.py +15 -15
  25. pytcl/clustering/kmeans.py +4 -4
  26. pytcl/containers/__init__.py +24 -0
  27. pytcl/containers/base.py +219 -0
  28. pytcl/containers/cluster_set.py +12 -2
  29. pytcl/containers/covertree.py +26 -29
  30. pytcl/containers/kd_tree.py +94 -29
  31. pytcl/containers/rtree.py +200 -1
  32. pytcl/containers/vptree.py +21 -28
  33. pytcl/coordinate_systems/conversions/geodetic.py +272 -5
  34. pytcl/coordinate_systems/jacobians/jacobians.py +2 -2
  35. pytcl/coordinate_systems/projections/__init__.py +1 -1
  36. pytcl/coordinate_systems/projections/projections.py +2 -2
  37. pytcl/coordinate_systems/rotations/rotations.py +10 -6
  38. pytcl/core/__init__.py +18 -0
  39. pytcl/core/validation.py +333 -2
  40. pytcl/dynamic_estimation/__init__.py +26 -0
  41. pytcl/dynamic_estimation/gaussian_sum_filter.py +434 -0
  42. pytcl/dynamic_estimation/imm.py +14 -14
  43. pytcl/dynamic_estimation/kalman/__init__.py +30 -0
  44. pytcl/dynamic_estimation/kalman/constrained.py +382 -0
  45. pytcl/dynamic_estimation/kalman/extended.py +8 -8
  46. pytcl/dynamic_estimation/kalman/h_infinity.py +613 -0
  47. pytcl/dynamic_estimation/kalman/square_root.py +60 -573
  48. pytcl/dynamic_estimation/kalman/sr_ukf.py +302 -0
  49. pytcl/dynamic_estimation/kalman/ud_filter.py +410 -0
  50. pytcl/dynamic_estimation/kalman/unscented.py +8 -6
  51. pytcl/dynamic_estimation/particle_filters/bootstrap.py +15 -15
  52. pytcl/dynamic_estimation/rbpf.py +589 -0
  53. pytcl/gravity/egm.py +13 -0
  54. pytcl/gravity/spherical_harmonics.py +98 -37
  55. pytcl/gravity/tides.py +6 -6
  56. pytcl/logging_config.py +328 -0
  57. pytcl/magnetism/__init__.py +7 -0
  58. pytcl/magnetism/emm.py +10 -3
  59. pytcl/magnetism/wmm.py +260 -23
  60. pytcl/mathematical_functions/combinatorics/combinatorics.py +5 -5
  61. pytcl/mathematical_functions/geometry/geometry.py +5 -5
  62. pytcl/mathematical_functions/numerical_integration/quadrature.py +6 -6
  63. pytcl/mathematical_functions/signal_processing/detection.py +24 -24
  64. pytcl/mathematical_functions/signal_processing/filters.py +14 -14
  65. pytcl/mathematical_functions/signal_processing/matched_filter.py +12 -12
  66. pytcl/mathematical_functions/special_functions/bessel.py +15 -3
  67. pytcl/mathematical_functions/special_functions/debye.py +136 -26
  68. pytcl/mathematical_functions/special_functions/error_functions.py +3 -1
  69. pytcl/mathematical_functions/special_functions/gamma_functions.py +4 -4
  70. pytcl/mathematical_functions/special_functions/hypergeometric.py +81 -15
  71. pytcl/mathematical_functions/transforms/fourier.py +8 -8
  72. pytcl/mathematical_functions/transforms/stft.py +12 -12
  73. pytcl/mathematical_functions/transforms/wavelets.py +9 -9
  74. pytcl/navigation/geodesy.py +246 -160
  75. pytcl/navigation/great_circle.py +101 -19
  76. pytcl/plotting/coordinates.py +7 -7
  77. pytcl/plotting/tracks.py +2 -2
  78. pytcl/static_estimation/maximum_likelihood.py +16 -14
  79. pytcl/static_estimation/robust.py +5 -5
  80. pytcl/terrain/loaders.py +5 -5
  81. pytcl/trackers/hypothesis.py +1 -1
  82. pytcl/trackers/mht.py +9 -9
  83. pytcl/trackers/multi_target.py +1 -1
  84. {nrl_tracker-0.22.5.dist-info → nrl_tracker-1.8.0.dist-info}/LICENSE +0 -0
  85. {nrl_tracker-0.22.5.dist-info → nrl_tracker-1.8.0.dist-info}/WHEEL +0 -0
  86. {nrl_tracker-0.22.5.dist-info → nrl_tracker-1.8.0.dist-info}/top_level.txt +0 -0
@@ -13,11 +13,17 @@ References
13
13
  Finding Best Matches in Logarithmic Expected Time," ACM TOMS, 1977.
14
14
  """
15
15
 
16
+ import logging
16
17
  from typing import List, NamedTuple, Optional, Tuple
17
18
 
18
19
  import numpy as np
19
20
  from numpy.typing import ArrayLike, NDArray
20
21
 
22
+ from pytcl.containers.base import BaseSpatialIndex, validate_query_input
23
+
24
+ # Module logger
25
+ _logger = logging.getLogger("pytcl.containers.kd_tree")
26
+
21
27
 
22
28
  class KDNode:
23
29
  """A node in the k-d tree.
@@ -66,7 +72,7 @@ class NearestNeighborResult(NamedTuple):
66
72
  distances: NDArray[np.floating]
67
73
 
68
74
 
69
- class KDTree:
75
+ class KDTree(BaseSpatialIndex):
70
76
  """
71
77
  K-D Tree for efficient spatial queries.
72
78
 
@@ -97,6 +103,11 @@ class KDTree:
97
103
 
98
104
  Query complexity is O(log n) on average for nearest neighbor search,
99
105
  though worst case is O(n) for highly unbalanced queries.
106
+
107
+ See Also
108
+ --------
109
+ BaseSpatialIndex : Abstract base class defining the spatial index interface.
110
+ BallTree : Alternative spatial index using hyperspheres.
100
111
  """
101
112
 
102
113
  def __init__(
@@ -104,17 +115,13 @@ class KDTree:
104
115
  data: ArrayLike,
105
116
  leaf_size: int = 10,
106
117
  ):
107
- self.data = np.asarray(data, dtype=np.float64)
108
-
109
- if self.data.ndim != 2:
110
- raise ValueError("Data must be 2-dimensional (n_samples, n_features)")
111
-
112
- self.n_samples, self.n_features = self.data.shape
118
+ super().__init__(data)
113
119
  self.leaf_size = leaf_size
114
120
 
115
121
  # Build the tree
116
122
  indices = np.arange(self.n_samples)
117
123
  self.root = self._build_tree(indices, depth=0)
124
+ _logger.debug("KDTree built with leaf_size=%d", leaf_size)
118
125
 
119
126
  def _build_tree(
120
127
  self,
@@ -173,12 +180,9 @@ class KDTree:
173
180
  >>> result.indices
174
181
  array([[0, 1]])
175
182
  """
176
- X = np.asarray(X, dtype=np.float64)
177
-
178
- if X.ndim == 1:
179
- X = X.reshape(1, -1)
180
-
183
+ X = validate_query_input(X, self.n_features)
181
184
  n_queries = X.shape[0]
185
+ _logger.debug("KDTree.query: %d queries, k=%d", n_queries, k)
182
186
 
183
187
  all_indices = np.zeros((n_queries, k), dtype=np.intp)
184
188
  all_distances = np.full((n_queries, k), np.inf)
@@ -263,11 +267,7 @@ class KDTree:
263
267
  >>> tree.query_radius([[0, 0]], r=1.5)
264
268
  [[0, 1, 2]]
265
269
  """
266
- X = np.asarray(X, dtype=np.float64)
267
-
268
- if X.ndim == 1:
269
- X = X.reshape(1, -1)
270
-
270
+ X = validate_query_input(X, self.n_features)
271
271
  n_queries = X.shape[0]
272
272
  results: List[List[int]] = []
273
273
 
@@ -331,7 +331,7 @@ class KDTree:
331
331
  return self.query_radius(X, r)
332
332
 
333
333
 
334
- class BallTree:
334
+ class BallTree(BaseSpatialIndex):
335
335
  """
336
336
  Ball Tree for efficient spatial queries.
337
337
 
@@ -357,6 +357,11 @@ class BallTree:
357
357
  -----
358
358
  Ball trees have O(n log n) construction and O(log n) average-case
359
359
  query time. They can outperform k-d trees in high dimensions.
360
+
361
+ See Also
362
+ --------
363
+ BaseSpatialIndex : Abstract base class defining the spatial index interface.
364
+ KDTree : Alternative spatial index using axis-aligned splits.
360
365
  """
361
366
 
362
367
  def __init__(
@@ -364,12 +369,7 @@ class BallTree:
364
369
  data: ArrayLike,
365
370
  leaf_size: int = 10,
366
371
  ):
367
- self.data = np.asarray(data, dtype=np.float64)
368
-
369
- if self.data.ndim != 2:
370
- raise ValueError("Data must be 2-dimensional")
371
-
372
- self.n_samples, self.n_features = self.data.shape
372
+ super().__init__(data)
373
373
  self.leaf_size = leaf_size
374
374
 
375
375
  # Build tree using indices
@@ -382,6 +382,7 @@ class BallTree:
382
382
  self._leaf_indices: List[Optional[NDArray[np.intp]]] = []
383
383
 
384
384
  self._build_tree(self._indices)
385
+ _logger.debug("BallTree built with leaf_size=%d", leaf_size)
385
386
 
386
387
  def _build_tree(
387
388
  self,
@@ -459,11 +460,7 @@ class BallTree:
459
460
  result : NearestNeighborResult
460
461
  Indices and distances of k nearest neighbors.
461
462
  """
462
- X = np.asarray(X, dtype=np.float64)
463
-
464
- if X.ndim == 1:
465
- X = X.reshape(1, -1)
466
-
463
+ X = validate_query_input(X, self.n_features)
467
464
  n_queries = X.shape[0]
468
465
  all_indices = np.zeros((n_queries, k), dtype=np.intp)
469
466
  all_distances = np.full((n_queries, k), np.inf)
@@ -478,6 +475,74 @@ class BallTree:
478
475
 
479
476
  return NearestNeighborResult(indices=all_indices, distances=all_distances)
480
477
 
478
+ def query_radius(
479
+ self,
480
+ X: ArrayLike,
481
+ r: float,
482
+ ) -> List[List[int]]:
483
+ """
484
+ Query the tree for all points within radius r.
485
+
486
+ Parameters
487
+ ----------
488
+ X : array_like
489
+ Query points of shape (n_queries, n_features) or (n_features,).
490
+ r : float
491
+ Query radius.
492
+
493
+ Returns
494
+ -------
495
+ indices : list of lists
496
+ For each query, a list of indices of points within radius r.
497
+ """
498
+ X = validate_query_input(X, self.n_features)
499
+ n_queries = X.shape[0]
500
+ results: List[List[int]] = []
501
+
502
+ for i in range(n_queries):
503
+ indices = self._query_radius_single(X[i], r)
504
+ results.append(indices)
505
+
506
+ return results
507
+
508
+ def _query_radius_single(
509
+ self,
510
+ query: NDArray[np.floating],
511
+ r: float,
512
+ ) -> List[int]:
513
+ """Find all points within radius r of query point."""
514
+ indices: List[int] = []
515
+
516
+ def _search(node_id: int) -> None:
517
+ if node_id < 0:
518
+ return
519
+
520
+ centroid = self._centroids[node_id]
521
+ radius = self._radii[node_id]
522
+
523
+ # Distance to ball surface
524
+ dist_to_center = np.sqrt(np.sum((query - centroid) ** 2))
525
+
526
+ # Prune if ball is farther than radius
527
+ if dist_to_center - radius > r:
528
+ return
529
+
530
+ if self._is_leaf[node_id]:
531
+ # Check all points in leaf
532
+ leaf_indices = self._leaf_indices[node_id]
533
+ if leaf_indices is not None:
534
+ for idx in leaf_indices:
535
+ dist = np.sqrt(np.sum((query - self.data[idx]) ** 2))
536
+ if dist <= r:
537
+ indices.append(idx)
538
+ else:
539
+ # Visit both children
540
+ _search(self._left[node_id])
541
+ _search(self._right[node_id])
542
+
543
+ _search(0)
544
+ return indices
545
+
481
546
  def _query_single(
482
547
  self,
483
548
  query: NDArray[np.floating],
pytcl/containers/rtree.py CHANGED
@@ -13,11 +13,17 @@ References
13
13
  Method for Points and Rectangles," ACM SIGMOD, 1990.
14
14
  """
15
15
 
16
+ import logging
16
17
  from typing import List, NamedTuple, Optional, Tuple
17
18
 
18
19
  import numpy as np
19
20
  from numpy.typing import ArrayLike, NDArray
20
21
 
22
+ from pytcl.containers.base import SpatialQueryResult, validate_query_input
23
+
24
+ # Module logger
25
+ _logger = logging.getLogger("pytcl.containers.rtree")
26
+
21
27
 
22
28
  class BoundingBox(NamedTuple):
23
29
  """Axis-aligned bounding box.
@@ -160,6 +166,9 @@ class RTree:
160
166
  An R-tree groups nearby objects and represents them with their
161
167
  minimum bounding rectangle. This allows efficient spatial queries.
162
168
 
169
+ Unlike KDTree and BallTree which only index points, RTree can index
170
+ bounding boxes of arbitrary size. It also supports dynamic insertion.
171
+
163
172
  Parameters
164
173
  ----------
165
174
  max_entries : int, optional
@@ -167,6 +176,13 @@ class RTree:
167
176
  min_entries : int, optional
168
177
  Minimum entries per node (except root). Default max_entries // 2.
169
178
 
179
+ Attributes
180
+ ----------
181
+ n_entries : int
182
+ Number of entries in the tree.
183
+ n_features : int
184
+ Dimensionality of the data (set after first insertion).
185
+
170
186
  Examples
171
187
  --------
172
188
  >>> tree = RTree()
@@ -179,6 +195,11 @@ class RTree:
179
195
  -----
180
196
  This implementation uses a simplified insertion algorithm.
181
197
  For production use, consider using R*-tree or packed R-tree variants.
198
+
199
+ See Also
200
+ --------
201
+ KDTree : Point-based spatial index using axis-aligned splits.
202
+ BallTree : Point-based spatial index using hyperspheres.
182
203
  """
183
204
 
184
205
  def __init__(
@@ -190,11 +211,70 @@ class RTree:
190
211
  self.min_entries = min_entries or max_entries // 2
191
212
  self.root = RTreeNode(is_leaf=True)
192
213
  self.n_entries = 0
214
+ self.n_features: Optional[int] = None
193
215
  self._data: List[BoundingBox] = []
216
+ self._points: Optional[NDArray[np.floating]] = None
217
+ _logger.debug("RTree initialized with max_entries=%d", max_entries)
218
+
219
+ @classmethod
220
+ def from_points(
221
+ cls,
222
+ data: ArrayLike,
223
+ max_entries: int = 10,
224
+ min_entries: Optional[int] = None,
225
+ ) -> "RTree":
226
+ """
227
+ Create an RTree from point data.
228
+
229
+ This factory method provides an interface similar to KDTree and BallTree,
230
+ allowing RTree to be used interchangeably for point queries.
231
+
232
+ Parameters
233
+ ----------
234
+ data : array_like
235
+ Data points of shape (n_samples, n_features).
236
+ max_entries : int, optional
237
+ Maximum entries per node. Default 10.
238
+ min_entries : int, optional
239
+ Minimum entries per node. Default max_entries // 2.
240
+
241
+ Returns
242
+ -------
243
+ tree : RTree
244
+ RTree with all points inserted.
245
+
246
+ Examples
247
+ --------
248
+ >>> points = np.array([[0, 0], [1, 0], [0, 1], [1, 1]])
249
+ >>> tree = RTree.from_points(points)
250
+ >>> result = tree.query([[0.1, 0.1]], k=2)
251
+ """
252
+ data = np.asarray(data, dtype=np.float64)
253
+ if data.ndim != 2:
254
+ raise ValueError(
255
+ f"Data must be 2-dimensional (n_samples, n_features), "
256
+ f"got shape {data.shape}"
257
+ )
258
+
259
+ tree = cls(max_entries=max_entries, min_entries=min_entries)
260
+ tree._points = data
261
+ tree.n_features = data.shape[1]
262
+ tree.insert_points(data)
263
+ _logger.debug(
264
+ "RTree.from_points: indexed %d points in %d dimensions",
265
+ data.shape[0],
266
+ data.shape[1],
267
+ )
268
+ return tree
194
269
 
195
270
  def __len__(self) -> int:
196
271
  return self.n_entries
197
272
 
273
+ def __repr__(self) -> str:
274
+ if self.n_features is not None:
275
+ return f"RTree(n_entries={self.n_entries}, n_features={self.n_features})"
276
+ return f"RTree(n_entries={self.n_entries})"
277
+
198
278
  def insert(self, bbox: BoundingBox, data_index: Optional[int] = None) -> int:
199
279
  """
200
280
  Insert a bounding box into the tree.
@@ -214,6 +294,10 @@ class RTree:
214
294
  if data_index is None:
215
295
  data_index = self.n_entries
216
296
 
297
+ # Track dimensionality
298
+ if self.n_features is None:
299
+ self.n_features = len(bbox.min_coords)
300
+
217
301
  self._data.append(bbox)
218
302
 
219
303
  # Find leaf to insert into
@@ -330,6 +414,121 @@ class RTree:
330
414
  current.update_bbox()
331
415
  current = current.parent
332
416
 
417
+ def query(
418
+ self,
419
+ X: ArrayLike,
420
+ k: int = 1,
421
+ ) -> SpatialQueryResult:
422
+ """
423
+ Query the tree for k nearest neighbors.
424
+
425
+ This method provides API compatibility with KDTree and BallTree.
426
+
427
+ Parameters
428
+ ----------
429
+ X : array_like
430
+ Query points of shape (n_queries, n_features) or (n_features,).
431
+ k : int, optional
432
+ Number of nearest neighbors. Default 1.
433
+
434
+ Returns
435
+ -------
436
+ result : SpatialQueryResult
437
+ Indices and distances of k nearest neighbors for each query.
438
+
439
+ Examples
440
+ --------
441
+ >>> tree = RTree.from_points(np.array([[0, 0], [1, 1], [2, 2]]))
442
+ >>> result = tree.query([[0.5, 0.5]], k=2)
443
+ >>> result.indices
444
+ array([[0, 1]])
445
+ """
446
+ if self.n_features is None:
447
+ raise ValueError("Cannot query empty RTree")
448
+
449
+ X = validate_query_input(X, self.n_features)
450
+ n_queries = X.shape[0]
451
+ _logger.debug("RTree.query: %d queries, k=%d", n_queries, k)
452
+
453
+ all_indices = np.zeros((n_queries, k), dtype=np.intp)
454
+ all_distances = np.full((n_queries, k), np.inf)
455
+
456
+ for i in range(n_queries):
457
+ indices, distances = self.nearest(X[i], k=k)
458
+ n_found = len(indices)
459
+ if n_found > 0:
460
+ all_indices[i, :n_found] = indices
461
+ all_distances[i, :n_found] = distances
462
+
463
+ return SpatialQueryResult(indices=all_indices, distances=all_distances)
464
+
465
+ def query_radius(
466
+ self,
467
+ X: ArrayLike,
468
+ r: float,
469
+ ) -> List[List[int]]:
470
+ """
471
+ Query the tree for all points within radius r.
472
+
473
+ This method provides API compatibility with KDTree and BallTree.
474
+
475
+ Parameters
476
+ ----------
477
+ X : array_like
478
+ Query points of shape (n_queries, n_features) or (n_features,).
479
+ r : float
480
+ Query radius.
481
+
482
+ Returns
483
+ -------
484
+ indices : list of lists
485
+ For each query, a list of indices of points within radius r.
486
+
487
+ Examples
488
+ --------
489
+ >>> tree = RTree.from_points(np.array([[0, 0], [1, 0], [0, 1], [5, 5]]))
490
+ >>> tree.query_radius([[0, 0]], r=1.5)
491
+ [[0, 1, 2]]
492
+ """
493
+ if self.n_features is None:
494
+ raise ValueError("Cannot query empty RTree")
495
+
496
+ X = validate_query_input(X, self.n_features)
497
+ n_queries = X.shape[0]
498
+ results: List[List[int]] = []
499
+
500
+ for i in range(n_queries):
501
+ query = X[i]
502
+ indices: List[int] = []
503
+
504
+ def search(node: RTreeNode) -> None:
505
+ if node.bbox is None:
506
+ return
507
+
508
+ # Minimum distance from query point to node's bounding box
509
+ clamped = np.clip(query, node.bbox.min_coords, node.bbox.max_coords)
510
+ min_dist = float(np.sqrt(np.sum((query - clamped) ** 2)))
511
+
512
+ # Prune if node is entirely outside radius
513
+ if min_dist > r:
514
+ return
515
+
516
+ if node.is_leaf:
517
+ for bbox, idx in node.entries:
518
+ # Distance to point (center of zero-volume box)
519
+ clamped_pt = np.clip(query, bbox.min_coords, bbox.max_coords)
520
+ dist = float(np.sqrt(np.sum((query - clamped_pt) ** 2)))
521
+ if dist <= r:
522
+ indices.append(idx)
523
+ else:
524
+ for child in node.children:
525
+ search(child)
526
+
527
+ search(self.root)
528
+ results.append(indices)
529
+
530
+ return results
531
+
333
532
  def query_intersect(self, query_bbox: BoundingBox) -> RTreeResult:
334
533
  """
335
534
  Find all entries intersecting a query box.
@@ -455,7 +654,7 @@ class RTree:
455
654
  query = np.asarray(query_point, dtype=np.float64)
456
655
  neighbors: List[Tuple[float, int]] = []
457
656
 
458
- def min_dist_to_box(point: NDArray, bbox: BoundingBox) -> float:
657
+ def min_dist_to_box(point: NDArray[np.floating], bbox: BoundingBox) -> float:
459
658
  """Minimum distance from point to bounding box."""
460
659
  clamped = np.clip(point, bbox.min_coords, bbox.max_coords)
461
660
  return float(np.sqrt(np.sum((point - clamped) ** 2)))
@@ -11,11 +11,17 @@ References
11
11
  neighbor search in general metric spaces," SODA 1993.
12
12
  """
13
13
 
14
- from typing import Callable, List, NamedTuple, Optional, Tuple
14
+ import logging
15
+ from typing import Any, Callable, List, NamedTuple, Optional, Tuple
15
16
 
16
17
  import numpy as np
17
18
  from numpy.typing import ArrayLike, NDArray
18
19
 
20
+ from pytcl.containers.base import MetricSpatialIndex, validate_query_input
21
+
22
+ # Module logger
23
+ _logger = logging.getLogger("pytcl.containers.vptree")
24
+
19
25
 
20
26
  class VPTreeResult(NamedTuple):
21
27
  """Result of VP-tree query.
@@ -56,7 +62,7 @@ class VPNode:
56
62
  self.right: Optional["VPNode"] = None
57
63
 
58
64
 
59
- class VPTree:
65
+ class VPTree(MetricSpatialIndex):
60
66
  """
61
67
  Vantage Point Tree for metric space nearest neighbor search.
62
68
 
@@ -89,32 +95,27 @@ class VPTree:
89
95
 
90
96
  Query complexity is O(log n) on average but can degrade to O(n)
91
97
  for pathological distance distributions.
98
+
99
+ See Also
100
+ --------
101
+ MetricSpatialIndex : Abstract base class for metric-based spatial indices.
102
+ CoverTree : Alternative metric space index with theoretical guarantees.
92
103
  """
93
104
 
94
105
  def __init__(
95
106
  self,
96
107
  data: ArrayLike,
97
- metric: Optional[Callable[[NDArray, NDArray], float]] = None,
108
+ metric: Optional[
109
+ Callable[[np.ndarray[Any, Any], np.ndarray[Any, Any]], float]
110
+ ] = None,
98
111
  ):
99
- self.data = np.asarray(data, dtype=np.float64)
100
-
101
- if self.data.ndim != 2:
102
- raise ValueError("Data must be 2-dimensional")
103
-
104
- self.n_samples, self.n_features = self.data.shape
105
-
106
- if metric is None:
107
- self.metric = self._euclidean_distance
108
- else:
109
- self.metric = metric
112
+ super().__init__(data, metric)
110
113
 
111
114
  # Build tree
112
115
  indices = np.arange(self.n_samples)
113
116
  self.root = self._build_tree(indices)
114
-
115
- def _euclidean_distance(self, x: NDArray, y: NDArray) -> float:
116
- """Default Euclidean distance metric."""
117
- return float(np.sqrt(np.sum((x - y) ** 2)))
117
+ metric_name = metric.__name__ if metric else "euclidean"
118
+ _logger.debug("VPTree built with metric=%s", metric_name)
118
119
 
119
120
  def _build_tree(self, indices: NDArray[np.intp]) -> Optional[VPNode]:
120
121
  """Recursively build the VP-tree."""
@@ -169,11 +170,7 @@ class VPTree:
169
170
  result : VPTreeResult
170
171
  Indices and distances of k nearest neighbors.
171
172
  """
172
- X = np.asarray(X, dtype=np.float64)
173
-
174
- if X.ndim == 1:
175
- X = X.reshape(1, -1)
176
-
173
+ X = validate_query_input(X, self.n_features)
177
174
  n_queries = X.shape[0]
178
175
 
179
176
  all_indices = np.zeros((n_queries, k), dtype=np.intp)
@@ -259,11 +256,7 @@ class VPTree:
259
256
  indices : list of lists
260
257
  For each query, list of indices within radius.
261
258
  """
262
- X = np.asarray(X, dtype=np.float64)
263
-
264
- if X.ndim == 1:
265
- X = X.reshape(1, -1)
266
-
259
+ X = validate_query_input(X, self.n_features)
267
260
  n_queries = X.shape[0]
268
261
  results: List[List[int]] = []
269
262