nrl-tracker 0.21.4__py3-none-any.whl → 1.7.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (95) hide show
  1. {nrl_tracker-0.21.4.dist-info → nrl_tracker-1.7.5.dist-info}/METADATA +57 -10
  2. nrl_tracker-1.7.5.dist-info/RECORD +165 -0
  3. pytcl/__init__.py +4 -3
  4. pytcl/assignment_algorithms/__init__.py +28 -0
  5. pytcl/assignment_algorithms/data_association.py +2 -7
  6. pytcl/assignment_algorithms/gating.py +10 -10
  7. pytcl/assignment_algorithms/jpda.py +40 -40
  8. pytcl/assignment_algorithms/nd_assignment.py +379 -0
  9. pytcl/assignment_algorithms/network_flow.py +371 -0
  10. pytcl/assignment_algorithms/three_dimensional/assignment.py +3 -3
  11. pytcl/astronomical/__init__.py +162 -8
  12. pytcl/astronomical/ephemerides.py +533 -0
  13. pytcl/astronomical/reference_frames.py +865 -56
  14. pytcl/astronomical/relativity.py +473 -0
  15. pytcl/astronomical/sgp4.py +710 -0
  16. pytcl/astronomical/special_orbits.py +532 -0
  17. pytcl/astronomical/tle.py +558 -0
  18. pytcl/atmosphere/__init__.py +45 -3
  19. pytcl/atmosphere/ionosphere.py +512 -0
  20. pytcl/atmosphere/nrlmsise00.py +809 -0
  21. pytcl/clustering/dbscan.py +2 -2
  22. pytcl/clustering/gaussian_mixture.py +3 -3
  23. pytcl/clustering/hierarchical.py +15 -15
  24. pytcl/clustering/kmeans.py +4 -4
  25. pytcl/containers/__init__.py +28 -21
  26. pytcl/containers/base.py +219 -0
  27. pytcl/containers/cluster_set.py +2 -1
  28. pytcl/containers/covertree.py +26 -29
  29. pytcl/containers/kd_tree.py +94 -29
  30. pytcl/containers/measurement_set.py +1 -9
  31. pytcl/containers/rtree.py +200 -1
  32. pytcl/containers/vptree.py +21 -28
  33. pytcl/coordinate_systems/conversions/geodetic.py +272 -5
  34. pytcl/coordinate_systems/jacobians/jacobians.py +2 -2
  35. pytcl/coordinate_systems/projections/__init__.py +4 -2
  36. pytcl/coordinate_systems/projections/projections.py +2 -2
  37. pytcl/coordinate_systems/rotations/rotations.py +10 -6
  38. pytcl/core/__init__.py +18 -0
  39. pytcl/core/validation.py +333 -2
  40. pytcl/dynamic_estimation/__init__.py +26 -0
  41. pytcl/dynamic_estimation/gaussian_sum_filter.py +434 -0
  42. pytcl/dynamic_estimation/imm.py +15 -18
  43. pytcl/dynamic_estimation/kalman/__init__.py +30 -0
  44. pytcl/dynamic_estimation/kalman/constrained.py +382 -0
  45. pytcl/dynamic_estimation/kalman/extended.py +9 -12
  46. pytcl/dynamic_estimation/kalman/h_infinity.py +613 -0
  47. pytcl/dynamic_estimation/kalman/square_root.py +60 -573
  48. pytcl/dynamic_estimation/kalman/sr_ukf.py +302 -0
  49. pytcl/dynamic_estimation/kalman/ud_filter.py +410 -0
  50. pytcl/dynamic_estimation/kalman/unscented.py +9 -10
  51. pytcl/dynamic_estimation/particle_filters/bootstrap.py +15 -15
  52. pytcl/dynamic_estimation/rbpf.py +589 -0
  53. pytcl/dynamic_estimation/smoothers.py +1 -5
  54. pytcl/dynamic_models/discrete_time/__init__.py +1 -5
  55. pytcl/dynamic_models/process_noise/__init__.py +1 -5
  56. pytcl/gravity/egm.py +13 -0
  57. pytcl/gravity/spherical_harmonics.py +98 -37
  58. pytcl/gravity/tides.py +6 -6
  59. pytcl/logging_config.py +328 -0
  60. pytcl/magnetism/__init__.py +10 -14
  61. pytcl/magnetism/emm.py +10 -3
  62. pytcl/magnetism/wmm.py +260 -23
  63. pytcl/mathematical_functions/combinatorics/combinatorics.py +5 -5
  64. pytcl/mathematical_functions/geometry/geometry.py +5 -5
  65. pytcl/mathematical_functions/interpolation/__init__.py +2 -2
  66. pytcl/mathematical_functions/numerical_integration/quadrature.py +6 -6
  67. pytcl/mathematical_functions/signal_processing/detection.py +24 -24
  68. pytcl/mathematical_functions/signal_processing/filters.py +14 -14
  69. pytcl/mathematical_functions/signal_processing/matched_filter.py +12 -12
  70. pytcl/mathematical_functions/special_functions/__init__.py +2 -2
  71. pytcl/mathematical_functions/special_functions/bessel.py +15 -3
  72. pytcl/mathematical_functions/special_functions/debye.py +136 -26
  73. pytcl/mathematical_functions/special_functions/error_functions.py +3 -1
  74. pytcl/mathematical_functions/special_functions/gamma_functions.py +4 -4
  75. pytcl/mathematical_functions/special_functions/hypergeometric.py +81 -15
  76. pytcl/mathematical_functions/transforms/fourier.py +8 -8
  77. pytcl/mathematical_functions/transforms/stft.py +12 -12
  78. pytcl/mathematical_functions/transforms/wavelets.py +9 -9
  79. pytcl/navigation/__init__.py +14 -10
  80. pytcl/navigation/geodesy.py +246 -160
  81. pytcl/navigation/great_circle.py +101 -19
  82. pytcl/navigation/ins.py +1 -5
  83. pytcl/plotting/coordinates.py +7 -7
  84. pytcl/plotting/tracks.py +2 -2
  85. pytcl/static_estimation/maximum_likelihood.py +16 -14
  86. pytcl/static_estimation/robust.py +5 -5
  87. pytcl/terrain/loaders.py +5 -5
  88. pytcl/trackers/__init__.py +3 -14
  89. pytcl/trackers/hypothesis.py +1 -1
  90. pytcl/trackers/mht.py +9 -9
  91. pytcl/trackers/multi_target.py +2 -5
  92. nrl_tracker-0.21.4.dist-info/RECORD +0 -148
  93. {nrl_tracker-0.21.4.dist-info → nrl_tracker-1.7.5.dist-info}/LICENSE +0 -0
  94. {nrl_tracker-0.21.4.dist-info → nrl_tracker-1.7.5.dist-info}/WHEEL +0 -0
  95. {nrl_tracker-0.21.4.dist-info → nrl_tracker-1.7.5.dist-info}/top_level.txt +0 -0
@@ -12,7 +12,7 @@ References
12
12
  with Noise," KDD 1996.
13
13
  """
14
14
 
15
- from typing import List, NamedTuple, Set
15
+ from typing import Any, List, NamedTuple, Set
16
16
 
17
17
  import numpy as np
18
18
  from numba import njit
@@ -42,7 +42,7 @@ class DBSCANResult(NamedTuple):
42
42
 
43
43
 
44
44
  @njit(cache=True)
45
- def _compute_distance_matrix(X: np.ndarray) -> np.ndarray:
45
+ def _compute_distance_matrix(X: np.ndarray[Any, Any]) -> np.ndarray[Any, Any]:
46
46
  """Compute pairwise Euclidean distance matrix (JIT-compiled)."""
47
47
  n = X.shape[0]
48
48
  dist = np.zeros((n, n), dtype=np.float64)
@@ -674,9 +674,9 @@ class GaussianMixture:
674
674
 
675
675
  def _gaussian_pdf(
676
676
  self,
677
- x: NDArray,
678
- mean: NDArray,
679
- cov: NDArray,
677
+ x: NDArray[np.floating],
678
+ mean: NDArray[np.floating],
679
+ cov: NDArray[np.floating],
680
680
  ) -> float:
681
681
  """Evaluate single Gaussian PDF."""
682
682
  n = len(x)
@@ -12,7 +12,7 @@ References
12
12
  """
13
13
 
14
14
  from enum import Enum
15
- from typing import List, Literal, NamedTuple, Optional
15
+ from typing import Any, List, Literal, NamedTuple, Optional
16
16
 
17
17
  import numpy as np
18
18
  from numba import njit
@@ -72,7 +72,7 @@ class HierarchicalResult(NamedTuple):
72
72
 
73
73
 
74
74
  @njit(cache=True)
75
- def _compute_distance_matrix_jit(X: np.ndarray) -> np.ndarray:
75
+ def _compute_distance_matrix_jit(X: np.ndarray[Any, Any]) -> np.ndarray[Any, Any]:
76
76
  """JIT-compiled pairwise Euclidean distance computation."""
77
77
  n = X.shape[0]
78
78
  n_features = X.shape[1]
@@ -112,43 +112,43 @@ def compute_distance_matrix(
112
112
 
113
113
 
114
114
  def _single_linkage(
115
- dist_i: NDArray,
116
- dist_j: NDArray,
115
+ dist_i: NDArray[Any],
116
+ dist_j: NDArray[Any],
117
117
  size_i: int,
118
118
  size_j: int,
119
- ) -> NDArray:
119
+ ) -> NDArray[Any]:
120
120
  """Single linkage: minimum of distances."""
121
121
  return np.minimum(dist_i, dist_j)
122
122
 
123
123
 
124
124
  def _complete_linkage(
125
- dist_i: NDArray,
126
- dist_j: NDArray,
125
+ dist_i: NDArray[Any],
126
+ dist_j: NDArray[Any],
127
127
  size_i: int,
128
128
  size_j: int,
129
- ) -> NDArray:
129
+ ) -> NDArray[Any]:
130
130
  """Complete linkage: maximum of distances."""
131
131
  return np.maximum(dist_i, dist_j)
132
132
 
133
133
 
134
134
  def _average_linkage(
135
- dist_i: NDArray,
136
- dist_j: NDArray,
135
+ dist_i: NDArray[Any],
136
+ dist_j: NDArray[Any],
137
137
  size_i: int,
138
138
  size_j: int,
139
- ) -> NDArray:
139
+ ) -> NDArray[Any]:
140
140
  """Average linkage: weighted average of distances."""
141
141
  return (size_i * dist_i + size_j * dist_j) / (size_i + size_j)
142
142
 
143
143
 
144
144
  def _ward_linkage(
145
- dist_i: NDArray,
146
- dist_j: NDArray,
145
+ dist_i: NDArray[Any],
146
+ dist_j: NDArray[Any],
147
147
  size_i: int,
148
148
  size_j: int,
149
- size_k: NDArray,
149
+ size_k: NDArray[Any],
150
150
  dist_ij: float,
151
- ) -> NDArray:
151
+ ) -> NDArray[Any]:
152
152
  """Ward's linkage: minimum variance merge."""
153
153
  total = size_i + size_j + size_k
154
154
  return np.sqrt(
@@ -10,7 +10,7 @@ References
10
10
  Careful Seeding," SODA 2007.
11
11
  """
12
12
 
13
- from typing import Literal, NamedTuple, Optional, Union
13
+ from typing import Any, Literal, NamedTuple, Optional, Union
14
14
 
15
15
  import numpy as np
16
16
  from numpy.typing import ArrayLike, NDArray
@@ -305,7 +305,7 @@ def _kmeans_single(
305
305
 
306
306
  # Handle empty clusters: keep old center
307
307
  for k in range(n_clusters):
308
- if np.all(new_centers[k] == 0) and np.any(labels == k) is False:
308
+ if np.all(new_centers[k] == 0) and not np.any(labels == k):
309
309
  new_centers[k] = centers[k]
310
310
 
311
311
  # Check convergence
@@ -336,8 +336,8 @@ def _kmeans_single(
336
336
  def kmeans_elbow(
337
337
  X: ArrayLike,
338
338
  k_range: Optional[range] = None,
339
- **kwargs,
340
- ) -> dict:
339
+ **kwargs: Any,
340
+ ) -> dict[str, Any]:
341
341
  """
342
342
  Compute K-means for a range of k values for elbow method.
343
343
 
@@ -3,8 +3,27 @@ Containers module.
3
3
 
4
4
  This module provides spatial data structures for efficient
5
5
  nearest neighbor queries, spatial indexing, and tracking containers.
6
+
7
+ Spatial Index Hierarchy
8
+ -----------------------
9
+ All spatial index structures inherit from BaseSpatialIndex which defines
10
+ a common interface for k-nearest neighbor and radius queries:
11
+
12
+ BaseSpatialIndex (abstract)
13
+ ├── KDTree - K-dimensional tree (Euclidean space)
14
+ ├── BallTree - Ball tree variant of KD-tree
15
+ ├── RTree - Rectangle tree for bounding boxes
16
+ └── MetricSpatialIndex (abstract)
17
+ ├── VPTree - Vantage point tree (any metric)
18
+ └── CoverTree - Cover tree (any metric)
6
19
  """
7
20
 
21
+ from pytcl.containers.base import (
22
+ BaseSpatialIndex,
23
+ MetricSpatialIndex,
24
+ SpatialQueryResult,
25
+ validate_query_input,
26
+ )
8
27
  from pytcl.containers.cluster_set import (
9
28
  ClusterSet,
10
29
  ClusterStats,
@@ -13,17 +32,8 @@ from pytcl.containers.cluster_set import (
13
32
  cluster_tracks_kmeans,
14
33
  compute_cluster_centroid,
15
34
  )
16
- from pytcl.containers.covertree import (
17
- CoverTree,
18
- CoverTreeNode,
19
- CoverTreeResult,
20
- )
21
- from pytcl.containers.kd_tree import (
22
- BallTree,
23
- KDNode,
24
- KDTree,
25
- NearestNeighborResult,
26
- )
35
+ from pytcl.containers.covertree import CoverTree, CoverTreeNode, CoverTreeResult
36
+ from pytcl.containers.kd_tree import BallTree, KDNode, KDTree, NearestNeighborResult
27
37
  from pytcl.containers.measurement_set import (
28
38
  Measurement,
29
39
  MeasurementQuery,
@@ -38,18 +48,15 @@ from pytcl.containers.rtree import (
38
48
  box_from_points,
39
49
  merge_boxes,
40
50
  )
41
- from pytcl.containers.track_list import (
42
- TrackList,
43
- TrackListStats,
44
- TrackQuery,
45
- )
46
- from pytcl.containers.vptree import (
47
- VPNode,
48
- VPTree,
49
- VPTreeResult,
50
- )
51
+ from pytcl.containers.track_list import TrackList, TrackListStats, TrackQuery
52
+ from pytcl.containers.vptree import VPNode, VPTree, VPTreeResult
51
53
 
52
54
  __all__ = [
55
+ # Base classes
56
+ "BaseSpatialIndex",
57
+ "MetricSpatialIndex",
58
+ "SpatialQueryResult",
59
+ "validate_query_input",
53
60
  # K-D Tree
54
61
  "KDNode",
55
62
  "NearestNeighborResult",
@@ -0,0 +1,219 @@
1
+ """
2
+ Base classes for spatial data structures.
3
+
4
+ This module provides abstract base classes that define the common interface
5
+ for spatial indexing data structures like KD-trees, VP-trees, R-trees, and
6
+ Cover trees.
7
+ """
8
+
9
+ import logging
10
+ from abc import ABC, abstractmethod
11
+ from typing import Any, Callable, List, NamedTuple, Optional
12
+
13
+ import numpy as np
14
+ from numpy.typing import ArrayLike, NDArray
15
+
16
+ # Module logger
17
+ _logger = logging.getLogger("pytcl.containers")
18
+
19
+
20
+ class SpatialQueryResult(NamedTuple):
21
+ """Result of a spatial query.
22
+
23
+ Attributes
24
+ ----------
25
+ indices : ndarray
26
+ Indices of matching points in the original data.
27
+ distances : ndarray
28
+ Distances to matching points.
29
+ """
30
+
31
+ indices: NDArray[np.intp]
32
+ distances: NDArray[np.floating]
33
+
34
+
35
+ class BaseSpatialIndex(ABC):
36
+ """
37
+ Abstract base class for spatial indexing data structures.
38
+
39
+ All spatial index implementations (KDTree, VPTree, RTree, CoverTree)
40
+ should inherit from this class and implement the required methods.
41
+
42
+ This provides a consistent interface for:
43
+ - Building the index from point data
44
+ - k-nearest neighbor queries
45
+ - Range/radius queries
46
+ - Dimension and size introspection
47
+
48
+ Parameters
49
+ ----------
50
+ data : array_like
51
+ Data points of shape (n_samples, n_features).
52
+
53
+ Attributes
54
+ ----------
55
+ data : ndarray
56
+ The indexed data points.
57
+ n_samples : int
58
+ Number of data points.
59
+ n_features : int
60
+ Dimensionality of data points.
61
+ """
62
+
63
+ def __init__(self, data: ArrayLike):
64
+ self.data = np.asarray(data, dtype=np.float64)
65
+
66
+ if self.data.ndim != 2:
67
+ raise ValueError(
68
+ f"Data must be 2-dimensional (n_samples, n_features), "
69
+ f"got shape {self.data.shape}"
70
+ )
71
+
72
+ self.n_samples, self.n_features = self.data.shape
73
+ _logger.debug(
74
+ "%s initialized with %d points in %d dimensions",
75
+ self.__class__.__name__,
76
+ self.n_samples,
77
+ self.n_features,
78
+ )
79
+
80
+ @abstractmethod
81
+ def query(
82
+ self,
83
+ X: ArrayLike,
84
+ k: int = 1,
85
+ ) -> SpatialQueryResult:
86
+ """
87
+ Query the index for k nearest neighbors.
88
+
89
+ Parameters
90
+ ----------
91
+ X : array_like
92
+ Query points of shape (n_queries, n_features) or (n_features,).
93
+ k : int, optional
94
+ Number of nearest neighbors to return. Default is 1.
95
+
96
+ Returns
97
+ -------
98
+ result : SpatialQueryResult
99
+ Named tuple with indices and distances of k nearest neighbors
100
+ for each query point.
101
+ """
102
+ pass
103
+
104
+ @abstractmethod
105
+ def query_radius(
106
+ self,
107
+ X: ArrayLike,
108
+ r: float,
109
+ ) -> List[List[int]]:
110
+ """
111
+ Query the index for all points within radius r.
112
+
113
+ Parameters
114
+ ----------
115
+ X : array_like
116
+ Query points of shape (n_queries, n_features) or (n_features,).
117
+ r : float
118
+ Search radius.
119
+
120
+ Returns
121
+ -------
122
+ indices : list of list of int
123
+ For each query point, a list of indices of data points
124
+ within distance r.
125
+ """
126
+ pass
127
+
128
+ def __len__(self) -> int:
129
+ """Return number of indexed points."""
130
+ return self.n_samples
131
+
132
+ def __repr__(self) -> str:
133
+ return (
134
+ f"{self.__class__.__name__}("
135
+ f"n_samples={self.n_samples}, n_features={self.n_features})"
136
+ )
137
+
138
+
139
+ class MetricSpatialIndex(BaseSpatialIndex):
140
+ """
141
+ Base class for metric space spatial indices.
142
+
143
+ Extends BaseSpatialIndex with support for custom distance metrics.
144
+ Used by VP-trees and Cover trees which can work with any metric.
145
+
146
+ Parameters
147
+ ----------
148
+ data : array_like
149
+ Data points of shape (n_samples, n_features).
150
+ metric : callable, optional
151
+ Distance function with signature metric(x, y) -> float.
152
+ Default is Euclidean distance.
153
+ """
154
+
155
+ def __init__(
156
+ self,
157
+ data: ArrayLike,
158
+ metric: Optional[Callable[[NDArray[Any], NDArray[Any]], float]] = None,
159
+ ):
160
+ super().__init__(data)
161
+
162
+ if metric is None:
163
+ self.metric = self._euclidean_distance
164
+ else:
165
+ self.metric = metric
166
+
167
+ @staticmethod
168
+ def _euclidean_distance(x: NDArray[Any], y: NDArray[Any]) -> float:
169
+ """Default Euclidean distance metric."""
170
+ return float(np.sqrt(np.sum((x - y) ** 2)))
171
+
172
+
173
+ def validate_query_input(
174
+ X: ArrayLike,
175
+ n_features: int,
176
+ ) -> NDArray[np.floating]:
177
+ """
178
+ Validate and reshape query input.
179
+
180
+ Parameters
181
+ ----------
182
+ X : array_like
183
+ Query points.
184
+ n_features : int
185
+ Expected number of features.
186
+
187
+ Returns
188
+ -------
189
+ X : ndarray
190
+ Validated query array of shape (n_queries, n_features).
191
+
192
+ Raises
193
+ ------
194
+ ValueError
195
+ If query has wrong number of features.
196
+ """
197
+ X = np.asarray(X, dtype=np.float64)
198
+
199
+ if X.ndim == 1:
200
+ X = X.reshape(1, -1)
201
+
202
+ if X.shape[1] != n_features:
203
+ _logger.warning(
204
+ "Query feature mismatch: got %d, expected %d", X.shape[1], n_features
205
+ )
206
+ raise ValueError(f"Query has {X.shape[1]} features, expected {n_features}")
207
+
208
+ _logger.debug(
209
+ "Validated query input: %d queries, %d features", X.shape[0], X.shape[1]
210
+ )
211
+ return X
212
+
213
+
214
+ __all__ = [
215
+ "SpatialQueryResult",
216
+ "BaseSpatialIndex",
217
+ "MetricSpatialIndex",
218
+ "validate_query_input",
219
+ ]
@@ -8,6 +8,7 @@ that move together (formations, convoys, etc.).
8
8
  from __future__ import annotations
9
9
 
10
10
  from typing import (
11
+ Any,
11
12
  Dict,
12
13
  Iterable,
13
14
  Iterator,
@@ -312,7 +313,7 @@ class ClusterSet:
312
313
  cls,
313
314
  tracks: TrackList,
314
315
  method: str = "dbscan",
315
- **kwargs,
316
+ **kwargs: Any,
316
317
  ) -> ClusterSet:
317
318
  """
318
319
  Create a ClusterSet by clustering tracks.
@@ -11,11 +11,17 @@ References
11
11
  neighbor," ICML 2006.
12
12
  """
13
13
 
14
- from typing import Callable, List, NamedTuple, Optional, Set, Tuple
14
+ import logging
15
+ from typing import Any, Callable, List, NamedTuple, Optional, Set, Tuple
15
16
 
16
17
  import numpy as np
17
18
  from numpy.typing import ArrayLike, NDArray
18
19
 
20
+ from pytcl.containers.base import MetricSpatialIndex, validate_query_input
21
+
22
+ # Module logger
23
+ _logger = logging.getLogger("pytcl.containers.covertree")
24
+
19
25
 
20
26
  class CoverTreeResult(NamedTuple):
21
27
  """Result of Cover tree query.
@@ -60,7 +66,7 @@ class CoverTreeNode:
60
66
  self.children[level].append(child)
61
67
 
62
68
 
63
- class CoverTree:
69
+ class CoverTree(MetricSpatialIndex):
64
70
  """
65
71
  Cover Tree for metric space nearest neighbor search.
66
72
 
@@ -93,27 +99,24 @@ class CoverTree:
93
99
 
94
100
  The implementation uses a simplified version of the original
95
101
  algorithm for clarity.
102
+
103
+ See Also
104
+ --------
105
+ MetricSpatialIndex : Abstract base class for metric-based spatial indices.
106
+ VPTree : Alternative metric space index using vantage points.
96
107
  """
97
108
 
98
109
  def __init__(
99
110
  self,
100
111
  data: ArrayLike,
101
- metric: Optional[Callable[[NDArray, NDArray], float]] = None,
112
+ metric: Optional[
113
+ Callable[[np.ndarray[Any, Any], np.ndarray[Any, Any]], float]
114
+ ] = None,
102
115
  base: float = 2.0,
103
116
  ):
104
- self.data = np.asarray(data, dtype=np.float64)
105
-
106
- if self.data.ndim != 2:
107
- raise ValueError("Data must be 2-dimensional")
108
-
109
- self.n_samples, self.n_features = self.data.shape
117
+ super().__init__(data, metric)
110
118
  self.base = base
111
119
 
112
- if metric is None:
113
- self.metric = self._euclidean_distance
114
- else:
115
- self.metric = metric
116
-
117
120
  # Compute distance cache for small datasets
118
121
  self._distance_cache: dict[Tuple[int, int], float] = {}
119
122
 
@@ -124,10 +127,12 @@ class CoverTree:
124
127
 
125
128
  if self.n_samples > 0:
126
129
  self._build_tree()
127
-
128
- def _euclidean_distance(self, x: NDArray, y: NDArray) -> float:
129
- """Default Euclidean distance metric."""
130
- return float(np.sqrt(np.sum((x - y) ** 2)))
130
+ _logger.debug(
131
+ "CoverTree built with base=%.1f, levels=%d to %d",
132
+ base,
133
+ self.min_level,
134
+ self.max_level,
135
+ )
131
136
 
132
137
  def _distance(self, i: int, j: int) -> float:
133
138
  """Get distance between points i and j (with caching)."""
@@ -138,7 +143,7 @@ class CoverTree:
138
143
  self._distance_cache[key] = self.metric(self.data[i], self.data[j])
139
144
  return self._distance_cache[key]
140
145
 
141
- def _distance_to_point(self, idx: int, query: NDArray) -> float:
146
+ def _distance_to_point(self, idx: int, query: NDArray[np.floating]) -> float:
142
147
  """Distance from data point to query point."""
143
148
  return self.metric(self.data[idx], query)
144
149
 
@@ -245,11 +250,7 @@ class CoverTree:
245
250
  result : CoverTreeResult
246
251
  Indices and distances of k nearest neighbors.
247
252
  """
248
- X = np.asarray(X, dtype=np.float64)
249
-
250
- if X.ndim == 1:
251
- X = X.reshape(1, -1)
252
-
253
+ X = validate_query_input(X, self.n_features)
253
254
  n_queries = X.shape[0]
254
255
 
255
256
  all_indices = np.zeros((n_queries, k), dtype=np.intp)
@@ -368,11 +369,7 @@ class CoverTree:
368
369
  indices : list of lists
369
370
  For each query, list of indices within radius.
370
371
  """
371
- X = np.asarray(X, dtype=np.float64)
372
-
373
- if X.ndim == 1:
374
- X = X.reshape(1, -1)
375
-
372
+ X = validate_query_input(X, self.n_features)
376
373
  n_queries = X.shape[0]
377
374
  results: List[List[int]] = []
378
375