nrl-tracker 1.1.3__py3-none-any.whl → 1.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,219 @@
1
+ """
2
+ Base classes for spatial data structures.
3
+
4
+ This module provides abstract base classes that define the common interface
5
+ for spatial indexing data structures like KD-trees, VP-trees, R-trees, and
6
+ Cover trees.
7
+ """
8
+
9
+ import logging
10
+ from abc import ABC, abstractmethod
11
+ from typing import Callable, List, NamedTuple, Optional
12
+
13
+ import numpy as np
14
+ from numpy.typing import ArrayLike, NDArray
15
+
16
+ # Module logger
17
+ _logger = logging.getLogger("pytcl.containers")
18
+
19
+
20
class SpatialQueryResult(NamedTuple):
    """Container for the output of a spatial index query.

    Attributes
    ----------
    indices : ndarray
        Positions, within the originally indexed data, of the points that
        matched the query.
    distances : ndarray
        Distances from the query to the matched points.
    """

    indices: NDArray[np.intp]
    distances: NDArray[np.floating]
33
+
34
+
35
class BaseSpatialIndex(ABC):
    """
    Common abstract interface shared by the spatial indexing structures.

    Concrete indices (KDTree, VPTree, RTree, CoverTree) subclass this type
    and supply ``query`` and ``query_radius``. The base class converts the
    input to a float64 array, checks that it is 2-dimensional, and records
    its size. Together these give every index:

    - construction from an array of points
    - k-nearest-neighbour lookup
    - fixed-radius lookup
    - size and dimensionality introspection (``len()``, ``n_features``)

    Parameters
    ----------
    data : array_like
        Points to index, shaped (n_samples, n_features).

    Attributes
    ----------
    data : ndarray
        The indexed points as a float64 array.
    n_samples : int
        How many points are indexed.
    n_features : int
        Number of coordinates per point.

    Raises
    ------
    ValueError
        If ``data`` is not 2-dimensional.
    """

    def __init__(self, data: ArrayLike):
        self.data = np.asarray(data, dtype=np.float64)
        shape = self.data.shape

        # Guard clause: only a (n_samples, n_features) table is accepted.
        if len(shape) != 2:
            raise ValueError(
                f"Data must be 2-dimensional (n_samples, n_features), "
                f"got shape {shape}"
            )

        self.n_samples, self.n_features = shape
        _logger.debug(
            "%s initialized with %d points in %d dimensions",
            type(self).__name__,
            self.n_samples,
            self.n_features,
        )

    @abstractmethod
    def query(
        self,
        X: ArrayLike,
        k: int = 1,
    ) -> SpatialQueryResult:
        """
        Look up the ``k`` nearest indexed points for each query point.

        Parameters
        ----------
        X : array_like
            A single point of shape (n_features,) or a batch of shape
            (n_queries, n_features).
        k : int, optional
            How many neighbours to return per query. Defaults to 1.

        Returns
        -------
        result : SpatialQueryResult
            Neighbour indices and their distances for every query point.
        """
        ...

    @abstractmethod
    def query_radius(
        self,
        X: ArrayLike,
        r: float,
    ) -> List[List[int]]:
        """
        Look up every indexed point lying within distance ``r``.

        Parameters
        ----------
        X : array_like
            A single point of shape (n_features,) or a batch of shape
            (n_queries, n_features).
        r : float
            The search radius.

        Returns
        -------
        indices : list of list of int
            One inner list per query point, holding the indices of the
            indexed points found within ``r`` of it.
        """
        ...

    def __len__(self) -> int:
        """Number of points held by the index."""
        return self.n_samples

    def __repr__(self) -> str:
        name = type(self).__name__
        return f"{name}(n_samples={self.n_samples}, n_features={self.n_features})"
137
+
138
+
139
class MetricSpatialIndex(BaseSpatialIndex):
    """
    Base class for metric space spatial indices.

    Extends BaseSpatialIndex with support for custom distance metrics.
    Used by VP-trees and Cover trees which can work with any metric.

    Parameters
    ----------
    data : array_like
        Data points of shape (n_samples, n_features).
    metric : callable, optional
        Distance function with signature metric(x, y) -> float.
        Default is Euclidean distance.

    Attributes
    ----------
    metric : callable
        The distance function used by the index. This is ``metric`` if one
        was supplied, otherwise :meth:`_euclidean_distance`.

    Raises
    ------
    ValueError
        If ``data`` is not 2-dimensional (checked by the base class).
    TypeError
        If ``metric`` is neither ``None`` nor callable.
    """

    def __init__(
        self,
        data: ArrayLike,
        metric: Optional[Callable[[NDArray, NDArray], float]] = None,
    ):
        super().__init__(data)

        if metric is None:
            self.metric = self._euclidean_distance
        elif callable(metric):
            self.metric = metric
        else:
            # Fail fast. Otherwise a non-callable metric is stored silently
            # and only errors at the first distance computation, far from
            # the real mistake.
            raise TypeError(
                "metric must be a callable metric(x, y) -> float or None, "
                f"got {type(metric).__name__}"
            )

    @staticmethod
    def _euclidean_distance(x: NDArray, y: NDArray) -> float:
        """Return the Euclidean (L2) distance between points ``x`` and ``y``."""
        return float(np.sqrt(np.sum((x - y) ** 2)))
171
+
172
+
173
def validate_query_input(
    X: ArrayLike,
    n_features: int,
) -> NDArray[np.floating]:
    """
    Validate and reshape query input.

    Converts ``X`` to a float64 array and promotes a single 1-D point to a
    batch containing one query, so callers can always iterate over rows.

    Parameters
    ----------
    X : array_like
        Query points. Either a single point of shape (n_features,) or a
        batch of shape (n_queries, n_features).
    n_features : int
        Expected number of features.

    Returns
    -------
    X : ndarray
        Validated float64 query array of shape (n_queries, n_features).

    Raises
    ------
    ValueError
        If the query is not 1- or 2-dimensional, or has the wrong number
        of features.
    """
    X = np.asarray(X, dtype=np.float64)

    if X.ndim == 1:
        # A single point: treat it as a batch of one query.
        X = X.reshape(1, -1)

    if X.ndim != 2:
        # Reject scalars and higher-rank arrays explicitly. For a scalar,
        # X.shape[1] below would raise IndexError. A 3-D array could pass
        # the feature check even though callers index rows with X[i] and
        # expect each row to be a single 1-D point.
        _logger.warning(
            "Query dimensionality mismatch: got ndim=%d, expected 1 or 2", X.ndim
        )
        raise ValueError(
            f"Query must have shape (n_features,) or (n_queries, n_features), "
            f"got shape {X.shape}"
        )

    if X.shape[1] != n_features:
        _logger.warning(
            "Query feature mismatch: got %d, expected %d", X.shape[1], n_features
        )
        raise ValueError(f"Query has {X.shape[1]} features, expected {n_features}")

    _logger.debug(
        "Validated query input: %d queries, %d features", X.shape[0], X.shape[1]
    )
    return X
212
+
213
+
214
# Names re-exported by ``from pytcl.containers.base import *``.
__all__ = [
    "SpatialQueryResult", "BaseSpatialIndex",
    "MetricSpatialIndex", "validate_query_input",
]
@@ -11,11 +11,17 @@ References
11
11
  neighbor," ICML 2006.
12
12
  """
13
13
 
14
+ import logging
14
15
  from typing import Callable, List, NamedTuple, Optional, Set, Tuple
15
16
 
16
17
  import numpy as np
17
18
  from numpy.typing import ArrayLike, NDArray
18
19
 
20
+ from pytcl.containers.base import MetricSpatialIndex, validate_query_input
21
+
22
+ # Module logger
23
+ _logger = logging.getLogger("pytcl.containers.covertree")
24
+
19
25
 
20
26
  class CoverTreeResult(NamedTuple):
21
27
  """Result of Cover tree query.
@@ -60,7 +66,7 @@ class CoverTreeNode:
60
66
  self.children[level].append(child)
61
67
 
62
68
 
63
- class CoverTree:
69
+ class CoverTree(MetricSpatialIndex):
64
70
  """
65
71
  Cover Tree for metric space nearest neighbor search.
66
72
 
@@ -93,6 +99,11 @@ class CoverTree:
93
99
 
94
100
  The implementation uses a simplified version of the original
95
101
  algorithm for clarity.
102
+
103
+ See Also
104
+ --------
105
+ MetricSpatialIndex : Abstract base class for metric-based spatial indices.
106
+ VPTree : Alternative metric space index using vantage points.
96
107
  """
97
108
 
98
109
  def __init__(
@@ -101,19 +112,9 @@ class CoverTree:
101
112
  metric: Optional[Callable[[NDArray, NDArray], float]] = None,
102
113
  base: float = 2.0,
103
114
  ):
104
- self.data = np.asarray(data, dtype=np.float64)
105
-
106
- if self.data.ndim != 2:
107
- raise ValueError("Data must be 2-dimensional")
108
-
109
- self.n_samples, self.n_features = self.data.shape
115
+ super().__init__(data, metric)
110
116
  self.base = base
111
117
 
112
- if metric is None:
113
- self.metric = self._euclidean_distance
114
- else:
115
- self.metric = metric
116
-
117
118
  # Compute distance cache for small datasets
118
119
  self._distance_cache: dict[Tuple[int, int], float] = {}
119
120
 
@@ -124,10 +125,12 @@ class CoverTree:
124
125
 
125
126
  if self.n_samples > 0:
126
127
  self._build_tree()
127
-
128
- def _euclidean_distance(self, x: NDArray, y: NDArray) -> float:
129
- """Default Euclidean distance metric."""
130
- return float(np.sqrt(np.sum((x - y) ** 2)))
128
+ _logger.debug(
129
+ "CoverTree built with base=%.1f, levels=%d to %d",
130
+ base,
131
+ self.min_level,
132
+ self.max_level,
133
+ )
131
134
 
132
135
  def _distance(self, i: int, j: int) -> float:
133
136
  """Get distance between points i and j (with caching)."""
@@ -245,11 +248,7 @@ class CoverTree:
245
248
  result : CoverTreeResult
246
249
  Indices and distances of k nearest neighbors.
247
250
  """
248
- X = np.asarray(X, dtype=np.float64)
249
-
250
- if X.ndim == 1:
251
- X = X.reshape(1, -1)
252
-
251
+ X = validate_query_input(X, self.n_features)
253
252
  n_queries = X.shape[0]
254
253
 
255
254
  all_indices = np.zeros((n_queries, k), dtype=np.intp)
@@ -368,11 +367,7 @@ class CoverTree:
368
367
  indices : list of lists
369
368
  For each query, list of indices within radius.
370
369
  """
371
- X = np.asarray(X, dtype=np.float64)
372
-
373
- if X.ndim == 1:
374
- X = X.reshape(1, -1)
375
-
370
+ X = validate_query_input(X, self.n_features)
376
371
  n_queries = X.shape[0]
377
372
  results: List[List[int]] = []
378
373
 
@@ -13,11 +13,17 @@ References
13
13
  Finding Best Matches in Logarithmic Expected Time," ACM TOMS, 1977.
14
14
  """
15
15
 
16
+ import logging
16
17
  from typing import List, NamedTuple, Optional, Tuple
17
18
 
18
19
  import numpy as np
19
20
  from numpy.typing import ArrayLike, NDArray
20
21
 
22
+ from pytcl.containers.base import BaseSpatialIndex, validate_query_input
23
+
24
+ # Module logger
25
+ _logger = logging.getLogger("pytcl.containers.kd_tree")
26
+
21
27
 
22
28
  class KDNode:
23
29
  """A node in the k-d tree.
@@ -66,7 +72,7 @@ class NearestNeighborResult(NamedTuple):
66
72
  distances: NDArray[np.floating]
67
73
 
68
74
 
69
- class KDTree:
75
+ class KDTree(BaseSpatialIndex):
70
76
  """
71
77
  K-D Tree for efficient spatial queries.
72
78
 
@@ -97,6 +103,11 @@ class KDTree:
97
103
 
98
104
  Query complexity is O(log n) on average for nearest neighbor search,
99
105
  though worst case is O(n) for highly unbalanced queries.
106
+
107
+ See Also
108
+ --------
109
+ BaseSpatialIndex : Abstract base class defining the spatial index interface.
110
+ BallTree : Alternative spatial index using hyperspheres.
100
111
  """
101
112
 
102
113
  def __init__(
@@ -104,17 +115,13 @@ class KDTree:
104
115
  data: ArrayLike,
105
116
  leaf_size: int = 10,
106
117
  ):
107
- self.data = np.asarray(data, dtype=np.float64)
108
-
109
- if self.data.ndim != 2:
110
- raise ValueError("Data must be 2-dimensional (n_samples, n_features)")
111
-
112
- self.n_samples, self.n_features = self.data.shape
118
+ super().__init__(data)
113
119
  self.leaf_size = leaf_size
114
120
 
115
121
  # Build the tree
116
122
  indices = np.arange(self.n_samples)
117
123
  self.root = self._build_tree(indices, depth=0)
124
+ _logger.debug("KDTree built with leaf_size=%d", leaf_size)
118
125
 
119
126
  def _build_tree(
120
127
  self,
@@ -173,12 +180,9 @@ class KDTree:
173
180
  >>> result.indices
174
181
  array([[0, 1]])
175
182
  """
176
- X = np.asarray(X, dtype=np.float64)
177
-
178
- if X.ndim == 1:
179
- X = X.reshape(1, -1)
180
-
183
+ X = validate_query_input(X, self.n_features)
181
184
  n_queries = X.shape[0]
185
+ _logger.debug("KDTree.query: %d queries, k=%d", n_queries, k)
182
186
 
183
187
  all_indices = np.zeros((n_queries, k), dtype=np.intp)
184
188
  all_distances = np.full((n_queries, k), np.inf)
@@ -263,11 +267,7 @@ class KDTree:
263
267
  >>> tree.query_radius([[0, 0]], r=1.5)
264
268
  [[0, 1, 2]]
265
269
  """
266
- X = np.asarray(X, dtype=np.float64)
267
-
268
- if X.ndim == 1:
269
- X = X.reshape(1, -1)
270
-
270
+ X = validate_query_input(X, self.n_features)
271
271
  n_queries = X.shape[0]
272
272
  results: List[List[int]] = []
273
273
 
@@ -331,7 +331,7 @@ class KDTree:
331
331
  return self.query_radius(X, r)
332
332
 
333
333
 
334
- class BallTree:
334
+ class BallTree(BaseSpatialIndex):
335
335
  """
336
336
  Ball Tree for efficient spatial queries.
337
337
 
@@ -357,6 +357,11 @@ class BallTree:
357
357
  -----
358
358
  Ball trees have O(n log n) construction and O(log n) average-case
359
359
  query time. They can outperform k-d trees in high dimensions.
360
+
361
+ See Also
362
+ --------
363
+ BaseSpatialIndex : Abstract base class defining the spatial index interface.
364
+ KDTree : Alternative spatial index using axis-aligned splits.
360
365
  """
361
366
 
362
367
  def __init__(
@@ -364,12 +369,7 @@ class BallTree:
364
369
  data: ArrayLike,
365
370
  leaf_size: int = 10,
366
371
  ):
367
- self.data = np.asarray(data, dtype=np.float64)
368
-
369
- if self.data.ndim != 2:
370
- raise ValueError("Data must be 2-dimensional")
371
-
372
- self.n_samples, self.n_features = self.data.shape
372
+ super().__init__(data)
373
373
  self.leaf_size = leaf_size
374
374
 
375
375
  # Build tree using indices
@@ -382,6 +382,7 @@ class BallTree:
382
382
  self._leaf_indices: List[Optional[NDArray[np.intp]]] = []
383
383
 
384
384
  self._build_tree(self._indices)
385
+ _logger.debug("BallTree built with leaf_size=%d", leaf_size)
385
386
 
386
387
  def _build_tree(
387
388
  self,
@@ -459,11 +460,7 @@ class BallTree:
459
460
  result : NearestNeighborResult
460
461
  Indices and distances of k nearest neighbors.
461
462
  """
462
- X = np.asarray(X, dtype=np.float64)
463
-
464
- if X.ndim == 1:
465
- X = X.reshape(1, -1)
466
-
463
+ X = validate_query_input(X, self.n_features)
467
464
  n_queries = X.shape[0]
468
465
  all_indices = np.zeros((n_queries, k), dtype=np.intp)
469
466
  all_distances = np.full((n_queries, k), np.inf)
@@ -478,6 +475,74 @@ class BallTree:
478
475
 
479
476
  return NearestNeighborResult(indices=all_indices, distances=all_distances)
480
477
 
478
def query_radius(
    self,
    X: ArrayLike,
    r: float,
) -> List[List[int]]:
    """
    Find, for every query point, the indexed points within distance ``r``.

    Parameters
    ----------
    X : array_like
        One query point of shape (n_features,) or a batch of shape
        (n_queries, n_features).
    r : float
        The search radius.

    Returns
    -------
    indices : list of lists
        One inner list per query row, holding the indices found for it.
    """
    queries = validate_query_input(X, self.n_features)
    # Iterating a 2-D array yields its rows, one query point at a time.
    return [self._query_radius_single(point, r) for point in queries]
507
+
508
def _query_radius_single(
    self,
    query: NDArray[np.floating],
    r: float,
) -> List[int]:
    """
    Find all indexed points within distance ``r`` of one query point.

    Walks the ball tree depth-first from node 0, skipping any node whose
    ball cannot contain a point within ``r``. Points stored in the
    surviving leaves are then checked exactly.

    Parameters
    ----------
    query : ndarray
        A single query point of shape (n_features,).
    r : float
        Search radius. Points at distance exactly ``r`` are included.

    Returns
    -------
    indices : list of int
        Indices of the matching points as plain Python ints, in
        left-first depth-first visit order.
    """
    found: List[int] = []

    # An empty tree has no root node to start from.
    if len(self._centroids) == 0:
        return found

    # Use an explicit stack rather than recursion. Pushing the right child
    # before the left one keeps the left-first visit order.
    stack = [0]
    while stack:
        node_id = stack.pop()
        if node_id < 0:
            # A negative id marks a missing child.
            continue

        # Distance from the query to the ball's center. Assuming each
        # _radii entry bounds its points, center distance minus radius is a
        # lower bound on the distance to any point inside the ball.
        dist_to_center = np.sqrt(np.sum((query - self._centroids[node_id]) ** 2))
        if dist_to_center - self._radii[node_id] > r:
            continue

        if self._is_leaf[node_id]:
            leaf_indices = self._leaf_indices[node_id]
            if leaf_indices is None:
                continue
            leaf = np.asarray(leaf_indices, dtype=np.intp)
            # Exact, vectorised distance check for every point in the leaf.
            dists = np.sqrt(np.sum((self.data[leaf] - query) ** 2, axis=1))
            # tolist() yields plain Python ints, matching the List[int]
            # annotation.
            found.extend(leaf[dists <= r].tolist())
        else:
            stack.append(self._right[node_id])
            stack.append(self._left[node_id])

    return found
545
+
481
546
  def _query_single(
482
547
  self,
483
548
  query: NDArray[np.floating],