nrl-tracker 1.9.2__py3-none-any.whl → 1.11.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: nrl-tracker
- Version: 1.9.2
+ Version: 1.11.0
  Summary: Python port of the U.S. Naval Research Laboratory's Tracker Component Library for target tracking algorithms
  Author: Original: David F. Crouse, Naval Research Laboratory
  Maintainer: Python Port Contributors
@@ -41,6 +41,8 @@ Requires-Dist: pytest>=7.0.0; extra == "dev"
  Requires-Dist: pytest-cov>=4.0.0; extra == "dev"
  Requires-Dist: pytest-xdist>=3.0.0; extra == "dev"
  Requires-Dist: pytest-benchmark>=4.0.0; extra == "dev"
+ Requires-Dist: pytest-timeout>=2.0.0; extra == "dev"
+ Requires-Dist: nbval>=0.10.0; extra == "dev"
  Requires-Dist: hypothesis>=6.0.0; extra == "dev"
  Requires-Dist: black>=23.0.0; extra == "dev"
  Requires-Dist: isort>=5.12.0; extra == "dev"
@@ -51,9 +53,15 @@ Requires-Dist: sphinx>=6.0.0; extra == "dev"
  Requires-Dist: sphinx-rtd-theme>=1.2.0; extra == "dev"
  Requires-Dist: myst-parser>=1.0.0; extra == "dev"
  Requires-Dist: nbsphinx>=0.9.0; extra == "dev"
+ Requires-Dist: jupyter>=1.0.0; extra == "dev"
+ Requires-Dist: ipykernel>=6.0.0; extra == "dev"
  Provides-Extra: geodesy
  Requires-Dist: pyproj>=3.4.0; extra == "geodesy"
  Requires-Dist: geographiclib>=2.0; extra == "geodesy"
+ Provides-Extra: gpu
+ Requires-Dist: cupy-cuda12x>=12.0.0; extra == "gpu"
+ Provides-Extra: gpu-apple
+ Requires-Dist: mlx>=0.5.0; extra == "gpu-apple"
  Provides-Extra: optimization
  Requires-Dist: cvxpy>=1.3.0; extra == "optimization"
  Provides-Extra: signal
@@ -63,17 +71,17 @@ Requires-Dist: plotly>=5.15.0; extra == "visualization"
 
  # Tracker Component Library (Python)
 
- [![PyPI version](https://img.shields.io/badge/pypi-v1.9.2-blue.svg)](https://pypi.org/project/nrl-tracker/)
+ [![PyPI version](https://img.shields.io/badge/pypi-v1.11.0-blue.svg)](https://pypi.org/project/nrl-tracker/)
  [![Python 3.10+](https://img.shields.io/badge/python-3.10+-blue.svg)](https://www.python.org/downloads/)
  [![License: Public Domain](https://img.shields.io/badge/License-Public%20Domain-brightgreen.svg)](https://en.wikipedia.org/wiki/Public_domain)
  [![Code style: black](https://img.shields.io/badge/code%20style-black-000000.svg)](https://github.com/psf/black)
- [![Tests](https://img.shields.io/badge/tests-2133%20passing-success.svg)](https://github.com/nedonatelli/TCL)
+ [![Tests](https://img.shields.io/badge/tests-2894%20passing-success.svg)](https://github.com/nedonatelli/TCL)
  [![MATLAB Parity](https://img.shields.io/badge/MATLAB%20Parity-100%25-brightgreen.svg)](docs/gap_analysis.rst)
  [![Type Checking](https://img.shields.io/badge/mypy--strict-passing-brightgreen.svg)](mypy.ini)
 
  A Python port of the [U.S. Naval Research Laboratory's Tracker Component Library](https://github.com/USNavalResearchLaboratory/TrackerComponentLibrary), a comprehensive collection of algorithms for target tracking, estimation, coordinate systems, and related mathematical functions.
 
- **1,070+ functions** | **153 modules** | **2,133 tests** | **100% MATLAB parity**
+ **1,070+ functions** | **153 modules** | **2,894 tests** | **100% MATLAB parity**
 
  ## Overview
 
@@ -90,6 +98,7 @@ The Tracker Component Library provides building blocks for developing target tra
  - **Navigation**: Geodetic calculations, INS mechanization, GNSS utilities, INS/GNSS integration
  - **Geophysical Models**: Gravity (WGS84, EGM96/2008), magnetism (WMM, IGRF), atmosphere, tides, terrain
  - **Signal Processing**: Digital filters, matched filtering, CFAR detection, transforms (FFT, STFT, wavelets)
+ - **GPU Acceleration**: CuPy (NVIDIA CUDA) and MLX (Apple Silicon) backends for batch Kalman filtering and particle filters
 
  ## Installation
 
@@ -111,6 +120,12 @@ pip install nrl-tracker[geodesy]
  # For visualization
  pip install nrl-tracker[visualization]
 
+ # For GPU acceleration (NVIDIA CUDA)
+ pip install nrl-tracker[gpu]
+
+ # For GPU acceleration (Apple Silicon M1/M2/M3)
+ pip install nrl-tracker[gpu-apple]
+
  # For development
  pip install nrl-tracker[dev]
 
@@ -183,6 +198,35 @@ assignment, total_cost = hungarian(cost_matrix)
  print(f"Optimal assignment: {assignment}, Total cost: {total_cost}")
  ```
 
+ ### GPU Acceleration
+
+ The library supports GPU acceleration for batch processing of multiple tracks:
+
+ ```python
+ from pytcl.gpu import is_gpu_available, get_backend, to_gpu, to_cpu
+
+ # Check GPU availability (auto-detects CUDA or Apple Silicon)
+ if is_gpu_available():
+     print(f"GPU available, using {get_backend()} backend")
+
+ # Transfer data to GPU
+ x_gpu = to_gpu(states)        # (n_tracks, state_dim)
+ P_gpu = to_gpu(covariances)   # (n_tracks, state_dim, state_dim)
+
+ # Use batch Kalman filter operations
+ from pytcl.gpu import batch_kf_predict
+ x_pred, P_pred = batch_kf_predict(x_gpu, P_gpu, F, Q)
+
+ # Transfer results back to CPU
+ x_pred_cpu = to_cpu(x_pred)
+ ```
+
+ **Supported backends:**
+ - **NVIDIA CUDA**: Via CuPy (`pip install nrl-tracker[gpu]`)
+ - **Apple Silicon**: Via MLX (`pip install nrl-tracker[gpu-apple]`)
+
+ The backend is automatically selected based on your platform.
+
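For orientation, the following plain-NumPy sketch shows what a batched predict step of this shape computes, using the standard Kalman filter prediction equations. It is an illustration only, not the `pytcl.gpu` implementation; the array layout mirrors the README example above.

```python
import numpy as np

def batch_kf_predict_reference(x, P, F, Q):
    """Reference semantics for a batched KF predict step.

    x: (n_tracks, n) stacked state means
    P: (n_tracks, n, n) stacked covariances
    F: (n, n) state transition matrix
    Q: (n, n) process noise covariance
    """
    x_pred = x @ F.T          # x_i' = F @ x_i for every track at once
    P_pred = F @ P @ F.T + Q  # F and Q broadcast over the leading track axis
    return x_pred, P_pred
```

The appeal of the batched formulation is that one matrix expression covers all tracks, which maps directly onto GPU array libraries such as CuPy or MLX.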
  ## Module Structure
 
  ```
@@ -202,6 +246,7 @@ pytcl/
  ├── gravity/              # Gravity models
  ├── magnetism/            # Magnetic field models
  ├── terrain/              # Terrain elevation models
+ ├── gpu/                  # GPU acceleration (CuPy/MLX)
  └── misc/                 # Utilities, visualization
  ```
 
nrl_tracker-1.9.2.dist-info/RECORD → nrl_tracker-1.11.0.dist-info/RECORD CHANGED
@@ -1,11 +1,11 @@
- pytcl/__init__.py,sha256=7ROietZU-pUiPRpioyeMjjZ6XybbmjRs1qRR_KoXioQ,2030
+ pytcl/__init__.py,sha256=5Px9PB57Sz5vZZ88WtlCY5q1z5VlW8Qjn33GLO5VitI,2032
  pytcl/logging_config.py,sha256=UJaYufQgNuIjpsOMTPo3ewz1XCHPk8a08jTHyP7uoI4,8956
  pytcl/assignment_algorithms/__init__.py,sha256=kUWhmyLhZcs5GiUQA5_v7KA3qETGsvqV6wU8r7paO-k,2976
  pytcl/assignment_algorithms/data_association.py,sha256=tsRxWJZk9aAPmE99BKXGouEpFfZrjPjb4HXvgxFUHhU,11405
  pytcl/assignment_algorithms/dijkstra_min_cost.py,sha256=z-Wk1HXRNKieBsRFqR8_UB8QvG5QkK3evazr8wzTpl0,5429
  pytcl/assignment_algorithms/gating.py,sha256=JaRaFcFqjfdsTbbTP6k_GY2zemDSR02l5yInWHpb05Y,11439
  pytcl/assignment_algorithms/jpda.py,sha256=rOY_v1vesL6EJySwD0kRDTfe7wHoDFLITg_lJLM-bX4,21731
- pytcl/assignment_algorithms/nd_assignment.py,sha256=qYTFZryqfnHXz1I1Kg2vTvEJIYla4JknedhKPXphdiQ,13269
+ pytcl/assignment_algorithms/nd_assignment.py,sha256=bcSNm3xSEjAg8gFb_TLQovpsLjNwvI5OOlh2y8XG4M0,24571
  pytcl/assignment_algorithms/network_flow.py,sha256=pPD63Z0-HOBv5XIqKUedt1KzTkcs0KG41DNojFZocDI,14459
  pytcl/assignment_algorithms/network_simplex.py,sha256=Qi10PsIYcTc6MZ-9GPl6ivaLaGA9F5-B7ltBbmasRNM,5566
  pytcl/assignment_algorithms/three_dimensional/__init__.py,sha256=1Q40OUlUQoo7YKEucwdrSNo3D4A0Zibvkr8z4TpueBg,526
@@ -46,7 +46,7 @@ pytcl/coordinate_systems/conversions/__init__.py,sha256=PkNevB78vBw0BkalydJBbQO9
  pytcl/coordinate_systems/conversions/geodetic.py,sha256=CarrTBW9rTC-CZ4E4YGxA8QjlpauuXJ2ZScnzc4QvK8,25001
  pytcl/coordinate_systems/conversions/spherical.py,sha256=GwuS1k0aUQ3AG1zZJouioMjxSIuEPRZMk-arvUCTh2k,11563
  pytcl/coordinate_systems/jacobians/__init__.py,sha256=CRGB8GzvGT_sr4Ynm51S7gSX8grqt1pO1Pq1MWmHPTs,890
- pytcl/coordinate_systems/jacobians/jacobians.py,sha256=0gpbelZPN4HDtvS1ymc3RIhOfxCVTKpRc-jDJXdM6pQ,11747
+ pytcl/coordinate_systems/jacobians/jacobians.py,sha256=IkEwyseGM1LeI2-cQEqzGD-lCplK-PVCHup7Bh3QPl4,12947
  pytcl/coordinate_systems/projections/__init__.py,sha256=TmBiffO5cmazAhsfPIVBaaqnravVSO3JxjGb0MXkucc,2404
  pytcl/coordinate_systems/projections/projections.py,sha256=y_kwcu_zp0HHiKR-wp3v3AvRcY61bleDi1SxwbrnWB0,33179
  pytcl/coordinate_systems/rotations/__init__.py,sha256=nqAz4iJd2hEOX_r7Tz4cE524sShyxdbtcQ5m56RrDLg,1047
@@ -56,7 +56,7 @@ pytcl/core/array_utils.py,sha256=SsgEiAoRCWxAVKq1aa5-nPdOi-2AB6XNObu0IaGClUk,139
  pytcl/core/constants.py,sha256=cwkCjzCU7zG2ZsFcbqwslN632v7Lw50L85s-5q892mo,9988
  pytcl/core/exceptions.py,sha256=6ImMiwL86BdmTt-Rc8fXLXxKUGQ-PcQQyxIvKKzw-n0,24324
  pytcl/core/maturity.py,sha256=Sut19NfH1-6f3Qd2QSC6OAqvDcVHJDwf5-F_-oEAMJA,11596
- pytcl/core/optional_deps.py,sha256=Xe7BG18SWsmzBD3zGa440U_QWKkfATBKhUfLOxhXZuU,15799
+ pytcl/core/optional_deps.py,sha256=a3UK_DM2s0XQE4Lwp0agq9L0qjupl_d8o4csCYbi440,16396
  pytcl/core/validation.py,sha256=4ay21cZVAil8udymwej7QnVQfNyjzi_5A8O1y-d-Lyw,23492
  pytcl/dynamic_estimation/__init__.py,sha256=zxmkZIXVfHPv5AHYpQV5nwsI0PA3m-Vw7W0gkJE7j98,5191
  pytcl/dynamic_estimation/gaussian_sum_filter.py,sha256=3Ks5-sGo3IF9p_dsIzk5u2zaXS2ZAkJFAg1mdxo8vj8,15343
@@ -70,7 +70,7 @@ pytcl/dynamic_estimation/kalman/constrained.py,sha256=Zidzz6_9OvwUyQppEltdmYTMvE
  pytcl/dynamic_estimation/kalman/extended.py,sha256=Yxc4Ve2aBtrkoelfMTFmzcXZefVZM0p0Z_a9n2IM1gQ,12032
  pytcl/dynamic_estimation/kalman/h_infinity.py,sha256=rtbYiryJbxzko-CIdNJSHuWXU2wI9T52YGBYq3o92sE,16563
  pytcl/dynamic_estimation/kalman/linear.py,sha256=gLFoCHjWtNHus_Nh4fTu67n_Xiv9QFVAuO5vO8MJICo,14673
- pytcl/dynamic_estimation/kalman/matrix_utils.py,sha256=couRVm0VKbhj9ctHcI-wcq8rj2MOapaSRVGuVdze3fQ,12426
+ pytcl/dynamic_estimation/kalman/matrix_utils.py,sha256=mcBKgYP3yl57SbyU7h92aDjytV3zQhhY6RBgm0RP-rc,14924
  pytcl/dynamic_estimation/kalman/square_root.py,sha256=RlDepNt7eJ1qbQkZElqfhcX2oJET09P9Q_P8Bv7LcJo,8199
  pytcl/dynamic_estimation/kalman/sr_ukf.py,sha256=Vys5uC58HSZSTLc9xfmWCjw_XnZZfD4MpFBXBX0OVzU,8912
  pytcl/dynamic_estimation/kalman/types.py,sha256=5sMEWAvd9kkE3EG9daYcG8uV70MBx_awC5u6KJkmiZw,2202
@@ -90,11 +90,18 @@ pytcl/dynamic_models/process_noise/__init__.py,sha256=ZRYgV40qmBkPwU3yTbIMvxorr4
  pytcl/dynamic_models/process_noise/coordinated_turn.py,sha256=0PciDXtXHjgQdaYf7qpQqIZ7qoMV4uO_kE7wjpiBe64,6483
  pytcl/dynamic_models/process_noise/polynomial.py,sha256=w5ZW5Ouw6QpVtev_mnuCmZoj6_O6ovb2L_ENKDhHYIc,7742
  pytcl/dynamic_models/process_noise/singer.py,sha256=ozAdzH4s0wGlBaxajdyZvSnK8_CumgsUZDKeMW-TxDs,5735
+ pytcl/gpu/__init__.py,sha256=aESvpn4Sa48xrQ4SIPb0j8uBt9bgiVHK_BgCXRLNY3o,4278
+ pytcl/gpu/ekf.py,sha256=KPaojhYrti9F74C71_Pgc22HKDJeBSUkyrA7Iis9-L4,12575
+ pytcl/gpu/kalman.py,sha256=8swMqLsnXjdl9-0vOg6wEqxtVHQRHcV4bXjHL8RwUmk,16417
+ pytcl/gpu/matrix_utils.py,sha256=x2SBjN6f21YUeOOKThBtmIPyBnAXhTCvWteTxJZlSs0,12601
+ pytcl/gpu/particle_filter.py,sha256=gqPt2ROFCkP-maFIlC8n7Td-ZNDZAN-42Ahen6TOfz8,17259
+ pytcl/gpu/ukf.py,sha256=83tclGEAs4LWxocvUHSk7JIoUHozQnqusxM1qk_iedk,13273
+ pytcl/gpu/utils.py,sha256=cedaW4evKeGCykFXI2QL_Ns8dU1yjL42MmYXf2gfGsw,14812
  pytcl/gravity/__init__.py,sha256=5xNdQSrrkt7-1-JPOYqR38CqvNJ7qKlPyMK36DGm6-I,3693
- pytcl/gravity/clenshaw.py,sha256=O7yYfjHMigR1RQHR_gZe3UuMIe_WsGrXFSLzn7PLfIE,16985
+ pytcl/gravity/clenshaw.py,sha256=zhEtIxUY6Uj8EMv7ucO3JMBqauA5shFKbUW0HO2hUfI,17240
  pytcl/gravity/egm.py,sha256=LAeNbaQ7eZakk0ciwLec0_8q41MrBFouVLpDsETis6o,19683
  pytcl/gravity/models.py,sha256=WqBwaOhQdGMx7MsYGYYNbwQLj8rgV-I_VhKZLFvmfso,11990
- pytcl/gravity/spherical_harmonics.py,sha256=bRUFVLgPQEJ8M5a_cJrJ-d5s5xTCmOs4fwRvdYaACuw,18522
+ pytcl/gravity/spherical_harmonics.py,sha256=SbCIlfNuJBwQ1nIJKo0DzgeEfW7RD_QnyKI0VhDSiGQ,18686
  pytcl/gravity/tides.py,sha256=NjsiXSiI7f-0qGr7G7YJVpIOVGzDxagz2S2vf_aRq68,28681
  pytcl/magnetism/__init__.py,sha256=pBASOzCPHNnYqUH_XDEblhGtjz50vY9uW2KS25A0zQQ,2701
  pytcl/magnetism/emm.py,sha256=iIdxSL0uGGIf8nfA-c_SmHvg9_J7HwRA2-qbQIUW6IE,22380
@@ -165,8 +172,8 @@ pytcl/trackers/mht.py,sha256=osEOXMaCeTt1eVn_E08dLRhEvBroVmf8b81zO5Zp1lU,20720
  pytcl/trackers/multi_target.py,sha256=RDITa0xnbgtVYAMj5XXp4lljo5lZ2zAAc02KZlOjxbQ,10526
  pytcl/trackers/single_target.py,sha256=Yy3FwaNTArMWcaod-0HVeiioNV4xLWxNDn_7ZPVqQYs,6562
  pytcl/transponders/__init__.py,sha256=5fL4u3lKCYgPLo5uFeuZbtRZkJPABntuKYGUvVgMMEI,41
- nrl_tracker-1.9.2.dist-info/LICENSE,sha256=rB5G4WppIIUzMOYr2N6uyYlNJ00hRJqE5tie6BMvYuE,1612
- nrl_tracker-1.9.2.dist-info/METADATA,sha256=HPSMMmbYsxCQpku97YqiIiaTut7fd0LOxLUBKCoMV-Y,12452
- nrl_tracker-1.9.2.dist-info/WHEEL,sha256=pL8R0wFFS65tNSRnaOVrsw9EOkOqxLrlUPenUYnJKNo,91
- nrl_tracker-1.9.2.dist-info/top_level.txt,sha256=17megxcrTPBWwPZTh6jTkwTKxX7No-ZqRpyvElnnO-s,6
- nrl_tracker-1.9.2.dist-info/RECORD,,
+ nrl_tracker-1.11.0.dist-info/LICENSE,sha256=rB5G4WppIIUzMOYr2N6uyYlNJ00hRJqE5tie6BMvYuE,1612
+ nrl_tracker-1.11.0.dist-info/METADATA,sha256=XU3LUdmSB3WwEn-r_0iaov-Ve80tFzJrbPHTibngc88,14038
+ nrl_tracker-1.11.0.dist-info/WHEEL,sha256=pL8R0wFFS65tNSRnaOVrsw9EOkOqxLrlUPenUYnJKNo,91
+ nrl_tracker-1.11.0.dist-info/top_level.txt,sha256=17megxcrTPBWwPZTh6jTkwTKxX7No-ZqRpyvElnnO-s,6
+ nrl_tracker-1.11.0.dist-info/RECORD,,
pytcl/__init__.py CHANGED
@@ -6,8 +6,8 @@ systems, dynamic models, estimation algorithms, and mathematical functions.
 
  This is a Python port of the U.S. Naval Research Laboratory's Tracker Component
  Library originally written in MATLAB.
- **Current Version:** 1.9.2 (January 4, 2026)
- **Status:** Production-ready, 2,133 tests passing, 76% line coverage
+ **Current Version:** 1.11.0 (January 5, 2026)
+ **Status:** Production-ready, 2,894 tests passing, 76% line coverage
  Examples
  --------
  >>> import pytcl as pytcl
@@ -21,7 +21,7 @@ References
  no. 5, pp. 18-27, May 2017.
  """
 
- __version__ = "1.9.2"
+ __version__ = "1.11.0"
  __author__ = "Python Port Contributors"
  __original_author__ = "David F. Crouse, Naval Research Laboratory"
 
pytcl/assignment_algorithms/nd_assignment.py CHANGED
@@ -9,6 +9,11 @@ enabling more complex assignment scenarios such as:
  The module provides a unified interface for solving high-dimensional
  assignment problems using generalized relaxation methods.
 
+ Performance Notes
+ -----------------
+ For sparse cost tensors (mostly invalid assignments), use SparseCostTensor
+ to reduce memory usage by 50% or more and improve performance on large problems.
+
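The storage model behind that claim is the `memory_savings` property added further down in this diff. As a rough worked example of the arithmetic (an illustration, not library code), consider a 10x10x10 tensor with 50 valid 3-index entries:

```python
# Assumed cost model, mirroring the memory_savings property below:
#   dense  ~ prod(dims) * 8 bytes             (one float64 per cell)
#   sparse ~ n_valid * (n_dims + 1) * 8 bytes (one cost plus one index per dimension)
dense_bytes = 10 * 10 * 10 * 8            # 8000
sparse_bytes = 50 * (3 + 1) * 8           # 1600
savings = 1 - sparse_bytes / dense_bytes  # 0.8 -> roughly 80% smaller
```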
  References
  ----------
  .. [1] Poore, A. B., "Multidimensional Assignment Problem and Data
@@ -18,7 +23,7 @@ References
  Drug Discovery," Perspectives in Drug Discovery and Design, 2003.
  """
 
- from typing import NamedTuple, Optional, Tuple
+ from typing import List, NamedTuple, Optional, Tuple, Union
  import numpy as np
  from numpy.typing import NDArray
 
@@ -442,3 +447,356 @@ def detect_dimension_conflicts(
              return True
 
      return False
+
+
+ class SparseCostTensor:
+     """
+     Sparse representation of N-dimensional cost tensor.
+
+     For assignment problems where most entries represent invalid
+     assignments (infinite cost), storing only valid entries reduces
+     memory by 50% or more and speeds up greedy algorithms.
+
+     Attributes
+     ----------
+     dims : tuple
+         Shape of the full tensor (n1, n2, ..., nk).
+     indices : ndarray
+         Array of shape (n_valid, n_dims) with valid entry indices.
+     costs : ndarray
+         Array of shape (n_valid,) with costs for valid entries.
+     default_cost : float
+         Cost for entries not explicitly stored (default: inf).
+
+     Examples
+     --------
+     >>> import numpy as np
+     >>> # Create sparse tensor for 10x10x10 problem with 50 valid entries
+     >>> dims = (10, 10, 10)
+     >>> valid_indices = np.random.randint(0, 10, size=(50, 3))
+     >>> valid_costs = np.random.rand(50)
+     >>> sparse = SparseCostTensor(dims, valid_indices, valid_costs)
+     >>> sparse.n_valid
+     50
+     >>> sparse.sparsity  # Fraction of valid entries
+     0.05
+
+     >>> # Convert from dense tensor with inf for invalid
+     >>> dense = np.full((5, 5, 5), np.inf)
+     >>> dense[0, 0, 0] = 1.0
+     >>> dense[1, 1, 1] = 2.0
+     >>> sparse = SparseCostTensor.from_dense(dense)
+     >>> sparse.n_valid
+     2
+     """
+
+     def __init__(
+         self,
+         dims: Tuple[int, ...],
+         indices: NDArray[np.intp],
+         costs: NDArray[np.float64],
+         default_cost: float = np.inf,
+     ):
+         """
+         Initialize sparse cost tensor.
+
+         Parameters
+         ----------
+         dims : tuple
+             Shape of the full tensor.
+         indices : ndarray
+             Valid entry indices, shape (n_valid, n_dims).
+         costs : ndarray
+             Costs for valid entries, shape (n_valid,).
+         default_cost : float
+             Cost for invalid (unstored) entries.
+         """
+         self.dims = dims
+         self.indices = np.asarray(indices, dtype=np.intp)
+         self.costs = np.asarray(costs, dtype=np.float64)
+         self.default_cost = default_cost
+
+         # Build lookup for O(1) cost retrieval
+         self._cost_map: dict[Tuple[int, ...], float] = {}
+         for i in range(len(self.costs)):
+             key = tuple(self.indices[i])
+             self._cost_map[key] = self.costs[i]
+
+     @property
+     def n_dims(self) -> int:
+         """Number of dimensions."""
+         return len(self.dims)
+
+     @property
+     def n_valid(self) -> int:
+         """Number of valid (finite cost) entries."""
+         return len(self.costs)
+
+     @property
+     def sparsity(self) -> float:
+         """Fraction of tensor that is valid (0 to 1)."""
+         total_size = int(np.prod(self.dims))
+         return self.n_valid / total_size if total_size > 0 else 0.0
+
+     @property
+     def memory_savings(self) -> float:
+         """Estimated memory savings vs dense representation (0 to 1)."""
+         dense_size = np.prod(self.dims) * 8  # 8 bytes per float64
+         sparse_size = self.n_valid * (8 + self.n_dims * 8)  # cost + indices
+         return max(0, 1 - sparse_size / dense_size) if dense_size > 0 else 0.0
+
+     def get_cost(self, index: Tuple[int, ...]) -> float:
+         """Get cost for a specific index tuple."""
+         return self._cost_map.get(index, self.default_cost)
+
+     def to_dense(self) -> NDArray[np.float64]:
+         """
+         Convert to dense tensor representation.
+
+         Returns
+         -------
+         dense : ndarray
+             Full tensor with default_cost for unstored entries.
+
+         Notes
+         -----
+         May use significant memory for large tensors.
+         """
+         dense = np.full(self.dims, self.default_cost, dtype=np.float64)
+         for i in range(len(self.costs)):
+             dense[tuple(self.indices[i])] = self.costs[i]
+         return dense
+
+     @classmethod
+     def from_dense(
+         cls,
+         dense: NDArray[np.float64],
+         threshold: float = 1e10,
+     ) -> "SparseCostTensor":
+         """
+         Create sparse tensor from dense array.
+
+         Parameters
+         ----------
+         dense : ndarray
+             Dense cost tensor.
+         threshold : float
+             Entries above this value are considered invalid.
+             Default 1e10 (catches np.inf and large values).
+
+         Returns
+         -------
+         SparseCostTensor
+             Sparse representation.
+
+         Examples
+         --------
+         >>> import numpy as np
+         >>> dense = np.array([[[1, np.inf], [np.inf, 2]],
+         ...                   [[np.inf, 3], [4, np.inf]]])
+         >>> sparse = SparseCostTensor.from_dense(dense)
+         >>> sparse.n_valid
+         4
+         """
+         valid_mask = dense < threshold
+         indices = np.array(np.where(valid_mask)).T
+         costs = dense[valid_mask]
+         return cls(dense.shape, indices, costs, default_cost=np.inf)
+
+
+ def greedy_assignment_nd_sparse(
+     sparse_cost: SparseCostTensor,
+     max_assignments: Optional[int] = None,
+ ) -> AssignmentNDResult:
+     """
+     Greedy solver for sparse N-dimensional assignment.
+
+     Selects minimum-cost tuples from valid entries only, which is much
+     faster than dense greedy when sparsity < 0.5.
+
+     Parameters
+     ----------
+     sparse_cost : SparseCostTensor
+         Sparse cost tensor with valid entries only.
+     max_assignments : int, optional
+         Maximum number of assignments (default: min(dimensions)).
+
+     Returns
+     -------
+     AssignmentNDResult
+         Assignments, total cost, and algorithm info.
+
+     Examples
+     --------
+     >>> import numpy as np
+     >>> # Create sparse problem
+     >>> dims = (10, 10, 10)
+     >>> # Only 20 valid assignments out of 1000
+     >>> indices = np.array([[i, i, i] for i in range(10)] +
+     ...                    [[i, (i+1)%10, (i+2)%10] for i in range(10)])
+     >>> costs = np.random.rand(20)
+     >>> sparse = SparseCostTensor(dims, indices, costs)
+     >>> result = greedy_assignment_nd_sparse(sparse)
+     >>> result.converged
+     True
+
+     Notes
+     -----
+     Time complexity is O(n_valid * log(n_valid)) vs O(total_size * log(total_size))
+     for dense greedy. For a 10x10x10 tensor with 50 valid entries, that is
+     50*log(50) vs 1000*log(1000) operations: roughly 35x less sorting work,
+     driven by having 20x fewer entries to sort.
+     """
+     dims = sparse_cost.dims
+     n_dims = sparse_cost.n_dims
+
+     if max_assignments is None:
+         max_assignments = min(dims)
+
+     # Sort valid entries by cost
+     sorted_indices = np.argsort(sparse_cost.costs)
+
+     assignments: List[Tuple[int, ...]] = []
+     used_indices: List[set[int]] = [set() for _ in range(n_dims)]
+     total_cost = 0.0
+
+     for sorted_idx in sorted_indices:
+         if len(assignments) >= max_assignments:
+             break
+
+         multi_idx = tuple(sparse_cost.indices[sorted_idx])
+
+         # Check if any dimension index is already used
+         conflict = False
+         for d, idx in enumerate(multi_idx):
+             if idx in used_indices[d]:
+                 conflict = True
+                 break
+
+         if not conflict:
+             assignments.append(multi_idx)
+             total_cost += sparse_cost.costs[sorted_idx]
+             for d, idx in enumerate(multi_idx):
+                 used_indices[d].add(idx)
+
+     assignments_array = np.array(assignments, dtype=np.intp)
+     if assignments_array.size == 0:
+         assignments_array = np.empty((0, n_dims), dtype=np.intp)
+
+     return AssignmentNDResult(
+         assignments=assignments_array,
+         cost=total_cost,
+         converged=True,
+         n_iterations=1,
+         gap=0.0,
+     )
+
+
+ def assignment_nd(
+     cost: Union[NDArray[np.float64], SparseCostTensor],
+     method: str = "auto",
+     max_assignments: Optional[int] = None,
+     max_iterations: int = 100,
+     tolerance: float = 1e-6,
+     epsilon: float = 0.01,
+     verbose: bool = False,
+ ) -> AssignmentNDResult:
+     """
+     Unified interface for N-dimensional assignment.
+
+     Automatically selects between dense and sparse algorithms based on
+     input type and sparsity.
+
+     Parameters
+     ----------
+     cost : ndarray or SparseCostTensor
+         Cost tensor (dense) or sparse cost representation.
+     method : str
+         Algorithm to use: 'auto', 'greedy', 'relaxation', 'auction'.
+         'auto' selects greedy for sparse, relaxation for dense.
+     max_assignments : int, optional
+         Maximum number of assignments for greedy methods.
+     max_iterations : int
+         Maximum iterations for iterative methods.
+     tolerance : float
+         Convergence tolerance for relaxation.
+     epsilon : float
+         Price increment for auction algorithm.
+     verbose : bool
+         Print progress information.
+
+     Returns
+     -------
+     AssignmentNDResult
+         Assignment solution.
+
+     Examples
+     --------
+     >>> import numpy as np
+     >>> # Dense usage
+     >>> cost = np.random.rand(4, 4, 4)
+     >>> result = assignment_nd(cost, method='greedy')
+     >>> result.converged
+     True
+
+     >>> # Sparse usage (more efficient for large sparse problems)
+     >>> dense = np.full((20, 20, 20), np.inf)
+     >>> for i in range(20):
+     ...     dense[i, i, i] = np.random.rand()
+     >>> sparse = SparseCostTensor.from_dense(dense)
+     >>> result = assignment_nd(sparse, method='auto')
+     >>> result.converged
+     True
+
+     See Also
+     --------
+     greedy_assignment_nd : Dense greedy algorithm.
+     greedy_assignment_nd_sparse : Sparse greedy algorithm.
+     relaxation_assignment_nd : Lagrangian relaxation.
+     auction_assignment_nd : Auction algorithm.
+     """
+     if isinstance(cost, SparseCostTensor):
+         # Sparse input - use sparse algorithm
+         if method in ("auto", "greedy"):
+             return greedy_assignment_nd_sparse(cost, max_assignments)
+         else:
+             # Convert to dense for other methods
+             dense = cost.to_dense()
+             if method == "relaxation":
+                 return relaxation_assignment_nd(
+                     dense, max_iterations, tolerance, verbose
+                 )
+             elif method == "auction":
+                 return auction_assignment_nd(
+                     dense, max_iterations, epsilon=epsilon, verbose=verbose
+                 )
+             else:
+                 raise ValueError(f"Unknown method: {method}")
+     else:
+         # Dense input
+         cost = np.asarray(cost, dtype=np.float64)
+         if method == "auto":
+             # Use relaxation for better solutions on dense
+             return relaxation_assignment_nd(cost, max_iterations, tolerance, verbose)
+         elif method == "greedy":
+             return greedy_assignment_nd(cost, max_assignments)
+         elif method == "relaxation":
+             return relaxation_assignment_nd(cost, max_iterations, tolerance, verbose)
+         elif method == "auction":
+             return auction_assignment_nd(
+                 cost, max_iterations, epsilon=epsilon, verbose=verbose
+             )
+         else:
+             raise ValueError(f"Unknown method: {method}")
+
+
+ __all__ = [
+     "AssignmentNDResult",
+     "SparseCostTensor",
+     "validate_cost_tensor",
+     "greedy_assignment_nd",
+     "greedy_assignment_nd_sparse",
+     "relaxation_assignment_nd",
+     "auction_assignment_nd",
+     "detect_dimension_conflicts",
+     "assignment_nd",
+ ]
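
Taken together, a minimal end-to-end sketch of the new sparse path, using only the names added above (`SparseCostTensor`, `assignment_nd`) and the module path listed in the RECORD; treat it as an illustration rather than a vetted recipe:

```python
import numpy as np
from pytcl.assignment_algorithms.nd_assignment import SparseCostTensor, assignment_nd

# A 3-D association problem where gating has marked most index tuples infeasible (inf).
dense = np.full((20, 20, 20), np.inf)
for i in range(20):
    dense[i, i, i] = np.random.rand()  # the only feasible tuples

sparse = SparseCostTensor.from_dense(dense)
print(f"{sparse.n_valid} valid entries, ~{sparse.memory_savings:.0%} memory saved vs dense")

# 'auto' dispatches SparseCostTensor input to greedy_assignment_nd_sparse.
result = assignment_nd(sparse, method="auto")
print(result.assignments.shape, result.cost, result.converged)
```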