nrl-tracker 1.1.3__py3-none-any.whl → 1.3.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {nrl_tracker-1.1.3.dist-info → nrl_tracker-1.3.0.dist-info}/METADATA +1 -1
- {nrl_tracker-1.1.3.dist-info → nrl_tracker-1.3.0.dist-info}/RECORD +28 -24
- pytcl/__init__.py +1 -1
- pytcl/astronomical/reference_frames.py +127 -55
- pytcl/atmosphere/__init__.py +32 -1
- pytcl/atmosphere/ionosphere.py +512 -0
- pytcl/containers/__init__.py +24 -0
- pytcl/containers/base.py +219 -0
- pytcl/containers/covertree.py +21 -26
- pytcl/containers/kd_tree.py +94 -29
- pytcl/containers/rtree.py +199 -0
- pytcl/containers/vptree.py +17 -26
- pytcl/core/__init__.py +18 -0
- pytcl/core/validation.py +331 -0
- pytcl/dynamic_estimation/kalman/square_root.py +52 -571
- pytcl/dynamic_estimation/kalman/sr_ukf.py +302 -0
- pytcl/dynamic_estimation/kalman/ud_filter.py +404 -0
- pytcl/gravity/egm.py +13 -0
- pytcl/gravity/spherical_harmonics.py +97 -36
- pytcl/magnetism/__init__.py +7 -0
- pytcl/magnetism/wmm.py +260 -23
- pytcl/mathematical_functions/special_functions/debye.py +132 -26
- pytcl/mathematical_functions/special_functions/hypergeometric.py +79 -15
- pytcl/navigation/geodesy.py +245 -159
- pytcl/navigation/great_circle.py +98 -16
- {nrl_tracker-1.1.3.dist-info → nrl_tracker-1.3.0.dist-info}/LICENSE +0 -0
- {nrl_tracker-1.1.3.dist-info → nrl_tracker-1.3.0.dist-info}/WHEEL +0 -0
- {nrl_tracker-1.1.3.dist-info → nrl_tracker-1.3.0.dist-info}/top_level.txt +0 -0
pytcl/containers/base.py
ADDED
|
@@ -0,0 +1,219 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Base classes for spatial data structures.
|
|
3
|
+
|
|
4
|
+
This module provides abstract base classes that define the common interface
|
|
5
|
+
for spatial indexing data structures like KD-trees, VP-trees, R-trees, and
|
|
6
|
+
Cover trees.
|
|
7
|
+
"""
|
|
8
|
+
|
|
9
|
+
import logging
|
|
10
|
+
from abc import ABC, abstractmethod
|
|
11
|
+
from typing import Callable, List, NamedTuple, Optional
|
|
12
|
+
|
|
13
|
+
import numpy as np
|
|
14
|
+
from numpy.typing import ArrayLike, NDArray
|
|
15
|
+
|
|
16
|
+
# Package-level logger for the spatial containers; per-tree modules create
# child loggers named "pytcl.containers.<module>".
_logger = logging.getLogger("pytcl.containers")
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
class SpatialQueryResult(NamedTuple):
    """Outcome of a spatial query against an index.

    Attributes
    ----------
    indices : ndarray
        Row positions of the matched points in the data the index was
        built from.
    distances : ndarray
        Distance from the query to each matched point, paired element-wise
        with ``indices``.
    """

    # Row indices into the originally indexed data.
    indices: NDArray[np.intp]
    # Distances that correspond position-by-position to ``indices``.
    distances: NDArray[np.floating]
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
class BaseSpatialIndex(ABC):
    """
    Abstract base class for spatial indexing data structures.

    All spatial index implementations (KDTree, VPTree, RTree, CoverTree)
    should inherit from this class and implement the required methods.

    This provides a consistent interface for:
    - Building the index from point data
    - k-nearest neighbor queries
    - Range/radius queries
    - Dimension and size introspection

    Parameters
    ----------
    data : array_like
        Data points of shape (n_samples, n_features).

    Attributes
    ----------
    data : ndarray
        The indexed data points, as float64.
    n_samples : int
        Number of data points.
    n_features : int
        Dimensionality of data points.

    Raises
    ------
    ValueError
        If ``data`` is not 2-dimensional.
    """

    def __init__(self, data: ArrayLike):
        # np.asarray does not copy when ``data`` is already a float64 ndarray,
        # so the index may share memory with the caller's array.
        self.data = np.asarray(data, dtype=np.float64)

        if self.data.ndim != 2:
            raise ValueError(
                "Data must be 2-dimensional (n_samples, n_features), "
                f"got shape {self.data.shape}"
            )

        self.n_samples, self.n_features = self.data.shape
        _logger.debug(
            "%s initialized with %d points in %d dimensions",
            type(self).__name__,
            self.n_samples,
            self.n_features,
        )

    @abstractmethod
    def query(
        self,
        X: ArrayLike,
        k: int = 1,
    ) -> SpatialQueryResult:
        """
        Query the index for k nearest neighbors.

        Parameters
        ----------
        X : array_like
            Query points of shape (n_queries, n_features) or (n_features,).
        k : int, optional
            Number of nearest neighbors to return. Default is 1.

        Returns
        -------
        result : SpatialQueryResult
            Named tuple with indices and distances of k nearest neighbors
            for each query point.
        """

    @abstractmethod
    def query_radius(
        self,
        X: ArrayLike,
        r: float,
    ) -> List[List[int]]:
        """
        Query the index for all points within radius r.

        Parameters
        ----------
        X : array_like
            Query points of shape (n_queries, n_features) or (n_features,).
        r : float
            Search radius.

        Returns
        -------
        indices : list of list of int
            For each query point, a list of indices of data points
            within distance r.
        """

    def __len__(self) -> int:
        """Return number of indexed points."""
        return self.n_samples

    def __repr__(self) -> str:
        """Return a summary showing the concrete class name and data shape."""
        return (
            f"{type(self).__name__}("
            f"n_samples={self.n_samples}, n_features={self.n_features})"
        )
|
|
137
|
+
|
|
138
|
+
|
|
139
|
+
class MetricSpatialIndex(BaseSpatialIndex):
    """
    Spatial-index base for structures that rely only on a distance function.

    Adds a configurable metric on top of BaseSpatialIndex. VP-trees and
    cover trees use this base because they work with any metric.

    Parameters
    ----------
    data : array_like
        Points to index, shape (n_samples, n_features).
    metric : callable, optional
        Distance function called as ``metric(x, y)`` and returning a float.
        When omitted, Euclidean distance is used.
    """

    def __init__(
        self,
        data: ArrayLike,
        metric: Optional[Callable[[NDArray, NDArray], float]] = None,
    ):
        super().__init__(data)
        # Use the built-in Euclidean metric unless the caller supplied one.
        self.metric = self._euclidean_distance if metric is None else metric

    @staticmethod
    def _euclidean_distance(x: NDArray, y: NDArray) -> float:
        """Return the Euclidean (L2) distance between ``x`` and ``y``."""
        delta = x - y
        squared_sum = np.sum(delta**2)
        return float(np.sqrt(squared_sum))
|
|
171
|
+
|
|
172
|
+
|
|
173
|
+
def validate_query_input(
    X: ArrayLike,
    n_features: int,
) -> NDArray[np.floating]:
    """
    Validate and reshape query input.

    A single point given as a 1-D array is promoted to a one-row 2-D array,
    so callers can always iterate over query rows.

    Parameters
    ----------
    X : array_like
        Query points, shape (n_features,) or (n_queries, n_features).
    n_features : int
        Expected number of features.

    Returns
    -------
    X : ndarray
        Validated float64 query array of shape (n_queries, n_features).

    Raises
    ------
    ValueError
        If the query is not 1- or 2-dimensional, or has the wrong number of
        features.
    """
    X = np.asarray(X, dtype=np.float64)

    if X.ndim == 1:
        # A single point: promote to shape (1, n_features).
        X = X.reshape(1, -1)
    elif X.ndim != 2:
        # Without this check a scalar crashes with IndexError on X.shape[1],
        # and a >2-D array passes the feature check whenever shape[1] happens
        # to equal n_features, yielding non-1-D "rows" downstream.
        _logger.warning("Query has invalid dimensionality: ndim=%d", X.ndim)
        raise ValueError(
            f"Query must be 1- or 2-dimensional, got shape {X.shape}"
        )

    if X.shape[1] != n_features:
        _logger.warning(
            "Query feature mismatch: got %d, expected %d", X.shape[1], n_features
        )
        raise ValueError(f"Query has {X.shape[1]} features, expected {n_features}")

    _logger.debug(
        "Validated query input: %d queries, %d features", X.shape[0], X.shape[1]
    )
    return X
|
|
212
|
+
|
|
213
|
+
|
|
214
|
+
# Public names exported by this module: the query result type, the two
# abstract index bases, and the shared query-validation helper.
__all__ = [
    "SpatialQueryResult",
    "BaseSpatialIndex",
    "MetricSpatialIndex",
    "validate_query_input",
]
|
pytcl/containers/covertree.py
CHANGED
|
@@ -11,11 +11,17 @@ References
|
|
|
11
11
|
neighbor," ICML 2006.
|
|
12
12
|
"""
|
|
13
13
|
|
|
14
|
+
import logging
|
|
14
15
|
from typing import Callable, List, NamedTuple, Optional, Set, Tuple
|
|
15
16
|
|
|
16
17
|
import numpy as np
|
|
17
18
|
from numpy.typing import ArrayLike, NDArray
|
|
18
19
|
|
|
20
|
+
from pytcl.containers.base import MetricSpatialIndex, validate_query_input
|
|
21
|
+
|
|
22
|
+
# Module logger
|
|
23
|
+
_logger = logging.getLogger("pytcl.containers.covertree")
|
|
24
|
+
|
|
19
25
|
|
|
20
26
|
class CoverTreeResult(NamedTuple):
|
|
21
27
|
"""Result of Cover tree query.
|
|
@@ -60,7 +66,7 @@ class CoverTreeNode:
|
|
|
60
66
|
self.children[level].append(child)
|
|
61
67
|
|
|
62
68
|
|
|
63
|
-
class CoverTree:
|
|
69
|
+
class CoverTree(MetricSpatialIndex):
|
|
64
70
|
"""
|
|
65
71
|
Cover Tree for metric space nearest neighbor search.
|
|
66
72
|
|
|
@@ -93,6 +99,11 @@ class CoverTree:
|
|
|
93
99
|
|
|
94
100
|
The implementation uses a simplified version of the original
|
|
95
101
|
algorithm for clarity.
|
|
102
|
+
|
|
103
|
+
See Also
|
|
104
|
+
--------
|
|
105
|
+
MetricSpatialIndex : Abstract base class for metric-based spatial indices.
|
|
106
|
+
VPTree : Alternative metric space index using vantage points.
|
|
96
107
|
"""
|
|
97
108
|
|
|
98
109
|
def __init__(
|
|
@@ -101,19 +112,9 @@ class CoverTree:
|
|
|
101
112
|
metric: Optional[Callable[[NDArray, NDArray], float]] = None,
|
|
102
113
|
base: float = 2.0,
|
|
103
114
|
):
|
|
104
|
-
|
|
105
|
-
|
|
106
|
-
if self.data.ndim != 2:
|
|
107
|
-
raise ValueError("Data must be 2-dimensional")
|
|
108
|
-
|
|
109
|
-
self.n_samples, self.n_features = self.data.shape
|
|
115
|
+
super().__init__(data, metric)
|
|
110
116
|
self.base = base
|
|
111
117
|
|
|
112
|
-
if metric is None:
|
|
113
|
-
self.metric = self._euclidean_distance
|
|
114
|
-
else:
|
|
115
|
-
self.metric = metric
|
|
116
|
-
|
|
117
118
|
# Compute distance cache for small datasets
|
|
118
119
|
self._distance_cache: dict[Tuple[int, int], float] = {}
|
|
119
120
|
|
|
@@ -124,10 +125,12 @@ class CoverTree:
|
|
|
124
125
|
|
|
125
126
|
if self.n_samples > 0:
|
|
126
127
|
self._build_tree()
|
|
127
|
-
|
|
128
|
-
|
|
129
|
-
|
|
130
|
-
|
|
128
|
+
_logger.debug(
|
|
129
|
+
"CoverTree built with base=%.1f, levels=%d to %d",
|
|
130
|
+
base,
|
|
131
|
+
self.min_level,
|
|
132
|
+
self.max_level,
|
|
133
|
+
)
|
|
131
134
|
|
|
132
135
|
def _distance(self, i: int, j: int) -> float:
|
|
133
136
|
"""Get distance between points i and j (with caching)."""
|
|
@@ -245,11 +248,7 @@ class CoverTree:
|
|
|
245
248
|
result : CoverTreeResult
|
|
246
249
|
Indices and distances of k nearest neighbors.
|
|
247
250
|
"""
|
|
248
|
-
X =
|
|
249
|
-
|
|
250
|
-
if X.ndim == 1:
|
|
251
|
-
X = X.reshape(1, -1)
|
|
252
|
-
|
|
251
|
+
X = validate_query_input(X, self.n_features)
|
|
253
252
|
n_queries = X.shape[0]
|
|
254
253
|
|
|
255
254
|
all_indices = np.zeros((n_queries, k), dtype=np.intp)
|
|
@@ -368,11 +367,7 @@ class CoverTree:
|
|
|
368
367
|
indices : list of lists
|
|
369
368
|
For each query, list of indices within radius.
|
|
370
369
|
"""
|
|
371
|
-
X =
|
|
372
|
-
|
|
373
|
-
if X.ndim == 1:
|
|
374
|
-
X = X.reshape(1, -1)
|
|
375
|
-
|
|
370
|
+
X = validate_query_input(X, self.n_features)
|
|
376
371
|
n_queries = X.shape[0]
|
|
377
372
|
results: List[List[int]] = []
|
|
378
373
|
|
pytcl/containers/kd_tree.py
CHANGED
|
@@ -13,11 +13,17 @@ References
|
|
|
13
13
|
Finding Best Matches in Logarithmic Expected Time," ACM TOMS, 1977.
|
|
14
14
|
"""
|
|
15
15
|
|
|
16
|
+
import logging
|
|
16
17
|
from typing import List, NamedTuple, Optional, Tuple
|
|
17
18
|
|
|
18
19
|
import numpy as np
|
|
19
20
|
from numpy.typing import ArrayLike, NDArray
|
|
20
21
|
|
|
22
|
+
from pytcl.containers.base import BaseSpatialIndex, validate_query_input
|
|
23
|
+
|
|
24
|
+
# Module logger
|
|
25
|
+
_logger = logging.getLogger("pytcl.containers.kd_tree")
|
|
26
|
+
|
|
21
27
|
|
|
22
28
|
class KDNode:
|
|
23
29
|
"""A node in the k-d tree.
|
|
@@ -66,7 +72,7 @@ class NearestNeighborResult(NamedTuple):
|
|
|
66
72
|
distances: NDArray[np.floating]
|
|
67
73
|
|
|
68
74
|
|
|
69
|
-
class KDTree:
|
|
75
|
+
class KDTree(BaseSpatialIndex):
|
|
70
76
|
"""
|
|
71
77
|
K-D Tree for efficient spatial queries.
|
|
72
78
|
|
|
@@ -97,6 +103,11 @@ class KDTree:
|
|
|
97
103
|
|
|
98
104
|
Query complexity is O(log n) on average for nearest neighbor search,
|
|
99
105
|
though worst case is O(n) for highly unbalanced queries.
|
|
106
|
+
|
|
107
|
+
See Also
|
|
108
|
+
--------
|
|
109
|
+
BaseSpatialIndex : Abstract base class defining the spatial index interface.
|
|
110
|
+
BallTree : Alternative spatial index using hyperspheres.
|
|
100
111
|
"""
|
|
101
112
|
|
|
102
113
|
def __init__(
|
|
@@ -104,17 +115,13 @@ class KDTree:
|
|
|
104
115
|
data: ArrayLike,
|
|
105
116
|
leaf_size: int = 10,
|
|
106
117
|
):
|
|
107
|
-
|
|
108
|
-
|
|
109
|
-
if self.data.ndim != 2:
|
|
110
|
-
raise ValueError("Data must be 2-dimensional (n_samples, n_features)")
|
|
111
|
-
|
|
112
|
-
self.n_samples, self.n_features = self.data.shape
|
|
118
|
+
super().__init__(data)
|
|
113
119
|
self.leaf_size = leaf_size
|
|
114
120
|
|
|
115
121
|
# Build the tree
|
|
116
122
|
indices = np.arange(self.n_samples)
|
|
117
123
|
self.root = self._build_tree(indices, depth=0)
|
|
124
|
+
_logger.debug("KDTree built with leaf_size=%d", leaf_size)
|
|
118
125
|
|
|
119
126
|
def _build_tree(
|
|
120
127
|
self,
|
|
@@ -173,12 +180,9 @@ class KDTree:
|
|
|
173
180
|
>>> result.indices
|
|
174
181
|
array([[0, 1]])
|
|
175
182
|
"""
|
|
176
|
-
X =
|
|
177
|
-
|
|
178
|
-
if X.ndim == 1:
|
|
179
|
-
X = X.reshape(1, -1)
|
|
180
|
-
|
|
183
|
+
X = validate_query_input(X, self.n_features)
|
|
181
184
|
n_queries = X.shape[0]
|
|
185
|
+
_logger.debug("KDTree.query: %d queries, k=%d", n_queries, k)
|
|
182
186
|
|
|
183
187
|
all_indices = np.zeros((n_queries, k), dtype=np.intp)
|
|
184
188
|
all_distances = np.full((n_queries, k), np.inf)
|
|
@@ -263,11 +267,7 @@ class KDTree:
|
|
|
263
267
|
>>> tree.query_radius([[0, 0]], r=1.5)
|
|
264
268
|
[[0, 1, 2]]
|
|
265
269
|
"""
|
|
266
|
-
X =
|
|
267
|
-
|
|
268
|
-
if X.ndim == 1:
|
|
269
|
-
X = X.reshape(1, -1)
|
|
270
|
-
|
|
270
|
+
X = validate_query_input(X, self.n_features)
|
|
271
271
|
n_queries = X.shape[0]
|
|
272
272
|
results: List[List[int]] = []
|
|
273
273
|
|
|
@@ -331,7 +331,7 @@ class KDTree:
|
|
|
331
331
|
return self.query_radius(X, r)
|
|
332
332
|
|
|
333
333
|
|
|
334
|
-
class BallTree:
|
|
334
|
+
class BallTree(BaseSpatialIndex):
|
|
335
335
|
"""
|
|
336
336
|
Ball Tree for efficient spatial queries.
|
|
337
337
|
|
|
@@ -357,6 +357,11 @@ class BallTree:
|
|
|
357
357
|
-----
|
|
358
358
|
Ball trees have O(n log n) construction and O(log n) average-case
|
|
359
359
|
query time. They can outperform k-d trees in high dimensions.
|
|
360
|
+
|
|
361
|
+
See Also
|
|
362
|
+
--------
|
|
363
|
+
BaseSpatialIndex : Abstract base class defining the spatial index interface.
|
|
364
|
+
KDTree : Alternative spatial index using axis-aligned splits.
|
|
360
365
|
"""
|
|
361
366
|
|
|
362
367
|
def __init__(
|
|
@@ -364,12 +369,7 @@ class BallTree:
|
|
|
364
369
|
data: ArrayLike,
|
|
365
370
|
leaf_size: int = 10,
|
|
366
371
|
):
|
|
367
|
-
|
|
368
|
-
|
|
369
|
-
if self.data.ndim != 2:
|
|
370
|
-
raise ValueError("Data must be 2-dimensional")
|
|
371
|
-
|
|
372
|
-
self.n_samples, self.n_features = self.data.shape
|
|
372
|
+
super().__init__(data)
|
|
373
373
|
self.leaf_size = leaf_size
|
|
374
374
|
|
|
375
375
|
# Build tree using indices
|
|
@@ -382,6 +382,7 @@ class BallTree:
|
|
|
382
382
|
self._leaf_indices: List[Optional[NDArray[np.intp]]] = []
|
|
383
383
|
|
|
384
384
|
self._build_tree(self._indices)
|
|
385
|
+
_logger.debug("BallTree built with leaf_size=%d", leaf_size)
|
|
385
386
|
|
|
386
387
|
def _build_tree(
|
|
387
388
|
self,
|
|
@@ -459,11 +460,7 @@ class BallTree:
|
|
|
459
460
|
result : NearestNeighborResult
|
|
460
461
|
Indices and distances of k nearest neighbors.
|
|
461
462
|
"""
|
|
462
|
-
X =
|
|
463
|
-
|
|
464
|
-
if X.ndim == 1:
|
|
465
|
-
X = X.reshape(1, -1)
|
|
466
|
-
|
|
463
|
+
X = validate_query_input(X, self.n_features)
|
|
467
464
|
n_queries = X.shape[0]
|
|
468
465
|
all_indices = np.zeros((n_queries, k), dtype=np.intp)
|
|
469
466
|
all_distances = np.full((n_queries, k), np.inf)
|
|
@@ -478,6 +475,74 @@ class BallTree:
|
|
|
478
475
|
|
|
479
476
|
return NearestNeighborResult(indices=all_indices, distances=all_distances)
|
|
480
477
|
|
|
478
|
+
def query_radius(
    self,
    X: ArrayLike,
    r: float,
) -> List[List[int]]:
    """
    Find, for every query point, the indexed points lying within ``r``.

    Parameters
    ----------
    X : array_like
        One query point of shape (n_features,) or several of shape
        (n_queries, n_features).
    r : float
        Query radius.

    Returns
    -------
    indices : list of lists
        One list per query row, holding indices of the points within
        radius ``r``.
    """
    queries = validate_query_input(X, self.n_features)
    # Each row of the validated 2-D array is answered independently.
    return [self._query_radius_single(point, r) for point in queries]
|
|
507
|
+
|
|
508
|
+
def _query_radius_single(
|
|
509
|
+
self,
|
|
510
|
+
query: NDArray[np.floating],
|
|
511
|
+
r: float,
|
|
512
|
+
) -> List[int]:
|
|
513
|
+
"""Find all points within radius r of query point."""
|
|
514
|
+
indices: List[int] = []
|
|
515
|
+
|
|
516
|
+
def _search(node_id: int) -> None:
|
|
517
|
+
if node_id < 0:
|
|
518
|
+
return
|
|
519
|
+
|
|
520
|
+
centroid = self._centroids[node_id]
|
|
521
|
+
radius = self._radii[node_id]
|
|
522
|
+
|
|
523
|
+
# Distance to ball surface
|
|
524
|
+
dist_to_center = np.sqrt(np.sum((query - centroid) ** 2))
|
|
525
|
+
|
|
526
|
+
# Prune if ball is farther than radius
|
|
527
|
+
if dist_to_center - radius > r:
|
|
528
|
+
return
|
|
529
|
+
|
|
530
|
+
if self._is_leaf[node_id]:
|
|
531
|
+
# Check all points in leaf
|
|
532
|
+
leaf_indices = self._leaf_indices[node_id]
|
|
533
|
+
if leaf_indices is not None:
|
|
534
|
+
for idx in leaf_indices:
|
|
535
|
+
dist = np.sqrt(np.sum((query - self.data[idx]) ** 2))
|
|
536
|
+
if dist <= r:
|
|
537
|
+
indices.append(idx)
|
|
538
|
+
else:
|
|
539
|
+
# Visit both children
|
|
540
|
+
_search(self._left[node_id])
|
|
541
|
+
_search(self._right[node_id])
|
|
542
|
+
|
|
543
|
+
_search(0)
|
|
544
|
+
return indices
|
|
545
|
+
|
|
481
546
|
def _query_single(
|
|
482
547
|
self,
|
|
483
548
|
query: NDArray[np.floating],
|