nrl-tracker 1.1.3__py3-none-any.whl → 1.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {nrl_tracker-1.1.3.dist-info → nrl_tracker-1.2.0.dist-info}/METADATA +1 -1
- {nrl_tracker-1.1.3.dist-info → nrl_tracker-1.2.0.dist-info}/RECORD +19 -18
- pytcl/__init__.py +1 -1
- pytcl/astronomical/reference_frames.py +127 -55
- pytcl/containers/__init__.py +24 -0
- pytcl/containers/base.py +219 -0
- pytcl/containers/covertree.py +21 -26
- pytcl/containers/kd_tree.py +94 -29
- pytcl/containers/vptree.py +17 -26
- pytcl/core/__init__.py +18 -0
- pytcl/core/validation.py +331 -0
- pytcl/gravity/egm.py +13 -0
- pytcl/gravity/spherical_harmonics.py +97 -36
- pytcl/mathematical_functions/special_functions/hypergeometric.py +79 -15
- pytcl/navigation/geodesy.py +245 -159
- pytcl/navigation/great_circle.py +98 -16
- {nrl_tracker-1.1.3.dist-info → nrl_tracker-1.2.0.dist-info}/LICENSE +0 -0
- {nrl_tracker-1.1.3.dist-info → nrl_tracker-1.2.0.dist-info}/WHEEL +0 -0
- {nrl_tracker-1.1.3.dist-info → nrl_tracker-1.2.0.dist-info}/top_level.txt +0 -0
pytcl/core/validation.py
CHANGED
|
@@ -472,3 +472,334 @@ def validated_array_input(
|
|
|
472
472
|
return wrapper
|
|
473
473
|
|
|
474
474
|
return decorator
|
|
475
|
+
|
|
476
|
+
|
|
477
|
+
class ArraySpec:
|
|
478
|
+
"""
|
|
479
|
+
Specification for array validation in @validate_inputs decorator.
|
|
480
|
+
|
|
481
|
+
Parameters
|
|
482
|
+
----------
|
|
483
|
+
dtype : type or np.dtype, optional
|
|
484
|
+
Required dtype.
|
|
485
|
+
ndim : int or tuple of int, optional
|
|
486
|
+
Required dimensionality.
|
|
487
|
+
shape : tuple, optional
|
|
488
|
+
Required shape (None for any size).
|
|
489
|
+
min_ndim : int, optional
|
|
490
|
+
Minimum dimensions required.
|
|
491
|
+
max_ndim : int, optional
|
|
492
|
+
Maximum dimensions allowed.
|
|
493
|
+
finite : bool, optional
|
|
494
|
+
Require all finite values.
|
|
495
|
+
non_negative : bool, optional
|
|
496
|
+
Require all values >= 0.
|
|
497
|
+
positive : bool, optional
|
|
498
|
+
Require all values > 0.
|
|
499
|
+
allow_empty : bool, optional
|
|
500
|
+
Allow empty arrays. Default True.
|
|
501
|
+
square : bool, optional
|
|
502
|
+
Require square matrix.
|
|
503
|
+
symmetric : bool, optional
|
|
504
|
+
Require symmetric matrix.
|
|
505
|
+
positive_definite : bool, optional
|
|
506
|
+
Require positive definite matrix.
|
|
507
|
+
|
|
508
|
+
Examples
|
|
509
|
+
--------
|
|
510
|
+
>>> spec = ArraySpec(ndim=2, finite=True, square=True)
|
|
511
|
+
>>> @validate_inputs(matrix=spec)
|
|
512
|
+
... def process_matrix(matrix):
|
|
513
|
+
... return np.linalg.inv(matrix)
|
|
514
|
+
"""
|
|
515
|
+
|
|
516
|
+
def __init__(
|
|
517
|
+
self,
|
|
518
|
+
*,
|
|
519
|
+
dtype: type | np.dtype | None = None,
|
|
520
|
+
ndim: int | tuple[int, ...] | None = None,
|
|
521
|
+
shape: tuple[int | None, ...] | None = None,
|
|
522
|
+
min_ndim: int | None = None,
|
|
523
|
+
max_ndim: int | None = None,
|
|
524
|
+
finite: bool = False,
|
|
525
|
+
non_negative: bool = False,
|
|
526
|
+
positive: bool = False,
|
|
527
|
+
allow_empty: bool = True,
|
|
528
|
+
square: bool = False,
|
|
529
|
+
symmetric: bool = False,
|
|
530
|
+
positive_definite: bool = False,
|
|
531
|
+
):
|
|
532
|
+
self.dtype = dtype
|
|
533
|
+
self.ndim = ndim
|
|
534
|
+
self.shape = shape
|
|
535
|
+
self.min_ndim = min_ndim
|
|
536
|
+
self.max_ndim = max_ndim
|
|
537
|
+
self.finite = finite
|
|
538
|
+
self.non_negative = non_negative
|
|
539
|
+
self.positive = positive
|
|
540
|
+
self.allow_empty = allow_empty
|
|
541
|
+
self.square = square
|
|
542
|
+
self.symmetric = symmetric
|
|
543
|
+
self.positive_definite = positive_definite
|
|
544
|
+
|
|
545
|
+
def validate(self, arr: ArrayLike, name: str) -> NDArray[Any]:
|
|
546
|
+
"""Validate an array against this specification."""
|
|
547
|
+
result = validate_array(
|
|
548
|
+
arr,
|
|
549
|
+
name,
|
|
550
|
+
dtype=self.dtype,
|
|
551
|
+
ndim=self.ndim,
|
|
552
|
+
shape=self.shape,
|
|
553
|
+
min_ndim=self.min_ndim,
|
|
554
|
+
max_ndim=self.max_ndim,
|
|
555
|
+
finite=self.finite,
|
|
556
|
+
non_negative=self.non_negative,
|
|
557
|
+
positive=self.positive,
|
|
558
|
+
allow_empty=self.allow_empty,
|
|
559
|
+
)
|
|
560
|
+
|
|
561
|
+
if self.positive_definite:
|
|
562
|
+
result = ensure_positive_definite(result, name)
|
|
563
|
+
elif self.symmetric:
|
|
564
|
+
result = ensure_symmetric(result, name)
|
|
565
|
+
elif self.square:
|
|
566
|
+
result = ensure_square_matrix(result, name)
|
|
567
|
+
|
|
568
|
+
return result
|
|
569
|
+
|
|
570
|
+
|
|
571
|
+
class ScalarSpec:
|
|
572
|
+
"""
|
|
573
|
+
Specification for scalar validation in @validate_inputs decorator.
|
|
574
|
+
|
|
575
|
+
Parameters
|
|
576
|
+
----------
|
|
577
|
+
dtype : type, optional
|
|
578
|
+
Required type (int, float, etc.).
|
|
579
|
+
min_value : float, optional
|
|
580
|
+
Minimum allowed value (inclusive).
|
|
581
|
+
max_value : float, optional
|
|
582
|
+
Maximum allowed value (inclusive).
|
|
583
|
+
finite : bool, optional
|
|
584
|
+
Require finite value.
|
|
585
|
+
positive : bool, optional
|
|
586
|
+
Require value > 0.
|
|
587
|
+
non_negative : bool, optional
|
|
588
|
+
Require value >= 0.
|
|
589
|
+
|
|
590
|
+
Examples
|
|
591
|
+
--------
|
|
592
|
+
>>> spec = ScalarSpec(dtype=int, min_value=1, max_value=10)
|
|
593
|
+
>>> @validate_inputs(k=spec)
|
|
594
|
+
... def get_k_nearest(k, data):
|
|
595
|
+
... return data[:k]
|
|
596
|
+
"""
|
|
597
|
+
|
|
598
|
+
def __init__(
|
|
599
|
+
self,
|
|
600
|
+
*,
|
|
601
|
+
dtype: type | None = None,
|
|
602
|
+
min_value: float | None = None,
|
|
603
|
+
max_value: float | None = None,
|
|
604
|
+
finite: bool = False,
|
|
605
|
+
positive: bool = False,
|
|
606
|
+
non_negative: bool = False,
|
|
607
|
+
):
|
|
608
|
+
self.dtype = dtype
|
|
609
|
+
self.min_value = min_value
|
|
610
|
+
self.max_value = max_value
|
|
611
|
+
self.finite = finite
|
|
612
|
+
self.positive = positive
|
|
613
|
+
self.non_negative = non_negative
|
|
614
|
+
|
|
615
|
+
def validate(self, value: Any, name: str) -> Any:
|
|
616
|
+
"""Validate a scalar value against this specification."""
|
|
617
|
+
# Type check
|
|
618
|
+
if self.dtype is not None:
|
|
619
|
+
if not isinstance(value, self.dtype):
|
|
620
|
+
try:
|
|
621
|
+
value = self.dtype(value)
|
|
622
|
+
except (ValueError, TypeError) as e:
|
|
623
|
+
raise ValidationError(
|
|
624
|
+
f"{name} must be {self.dtype.__name__}, got {type(value).__name__}"
|
|
625
|
+
) from e
|
|
626
|
+
|
|
627
|
+
# Convert to float for numeric checks
|
|
628
|
+
try:
|
|
629
|
+
num_value = float(value)
|
|
630
|
+
except (ValueError, TypeError):
|
|
631
|
+
if any(
|
|
632
|
+
[
|
|
633
|
+
self.finite,
|
|
634
|
+
self.positive,
|
|
635
|
+
self.non_negative,
|
|
636
|
+
self.min_value is not None,
|
|
637
|
+
self.max_value is not None,
|
|
638
|
+
]
|
|
639
|
+
):
|
|
640
|
+
raise ValidationError(
|
|
641
|
+
f"{name} must be numeric for range validation"
|
|
642
|
+
) from None
|
|
643
|
+
return value
|
|
644
|
+
|
|
645
|
+
# Finite check
|
|
646
|
+
if self.finite and not np.isfinite(num_value):
|
|
647
|
+
raise ValidationError(f"{name} must be finite, got {value}")
|
|
648
|
+
|
|
649
|
+
# Positive check
|
|
650
|
+
if self.positive and num_value <= 0:
|
|
651
|
+
raise ValidationError(f"{name} must be positive, got {value}")
|
|
652
|
+
|
|
653
|
+
# Non-negative check
|
|
654
|
+
if self.non_negative and num_value < 0:
|
|
655
|
+
raise ValidationError(f"{name} must be non-negative, got {value}")
|
|
656
|
+
|
|
657
|
+
# Range checks
|
|
658
|
+
if self.min_value is not None and num_value < self.min_value:
|
|
659
|
+
raise ValidationError(f"{name} must be >= {self.min_value}, got {value}")
|
|
660
|
+
|
|
661
|
+
if self.max_value is not None and num_value > self.max_value:
|
|
662
|
+
raise ValidationError(f"{name} must be <= {self.max_value}, got {value}")
|
|
663
|
+
|
|
664
|
+
return value
|
|
665
|
+
|
|
666
|
+
|
|
667
|
+
def validate_inputs(
    **param_specs: ArraySpec | ScalarSpec | dict[str, Any],
) -> Callable[[F], F]:
    """
    Decorator for validating multiple function parameters.

    This decorator enables declarative input validation using specification
    objects (ArraySpec, ScalarSpec) or dictionaries of validation options.

    Parameters
    ----------
    **param_specs : ArraySpec | ScalarSpec | dict
        Keyword arguments mapping parameter names to validation specs.
        Each spec can be:
        - ArraySpec: For array validation
        - ScalarSpec: For scalar validation
        - dict: Options passed to ArraySpec (for convenience)

    Returns
    -------
    Callable
        Decorated function with input validation.

    Raises
    ------
    TypeError
        At decoration time, if a spec is not an ArraySpec, ScalarSpec or
        dict, if a dict spec contains options ArraySpec does not accept, or
        if a spec names a parameter the decorated function does not declare.

    Examples
    --------
    >>> @validate_inputs(
    ...     x=ArraySpec(ndim=2, finite=True),
    ...     P=ArraySpec(ndim=2, positive_definite=True),
    ...     k=ScalarSpec(dtype=int, min_value=1),
    ... )
    ... def kalman_update(x, P, z, H, R, k=1):
    ...     # x and P are guaranteed valid here
    ...     pass

    Using dict shorthand:

    >>> @validate_inputs(
    ...     state={"ndim": 1, "finite": True},
    ...     covariance={"ndim": 2, "positive_definite": True},
    ... )
    ... def predict(state, covariance, dt):
    ...     pass

    Notes
    -----
    Validation happens in the order parameters are defined in the decorator.
    If any validation fails, a ValidationError is raised with a descriptive
    message identifying the parameter and the constraint violated.
    Parameters that take their default value are validated too, because
    defaults are filled in before validation.

    See Also
    --------
    ArraySpec : Specification class for array validation.
    ScalarSpec : Specification class for scalar validation.
    validate_array : Lower-level array validation function.
    """

    def decorator(func: F) -> F:
        import inspect

        # Pre-fetch signature for efficiency
        sig = inspect.signature(func)
        func_label = getattr(func, "__qualname__", repr(func))

        # Normalize and check every spec once, when the decorator is applied,
        # instead of on every call. This also surfaces misconfiguration
        # (a bad spec type, or a misspelled parameter name that previously
        # disabled validation without any error) at import time.
        specs: dict[str, ArraySpec | ScalarSpec] = {}
        for param_name, spec in param_specs.items():
            if param_name not in sig.parameters:
                raise TypeError(
                    f"validate_inputs: {func_label}() has no parameter "
                    f"named {param_name!r}"
                )
            if isinstance(spec, dict):
                spec = ArraySpec(**spec)
            if not isinstance(spec, (ArraySpec, ScalarSpec)):
                raise TypeError(
                    f"Invalid spec type for {param_name}: {type(spec)}. "
                    "Use ArraySpec, ScalarSpec, or dict."
                )
            specs[param_name] = spec

        @wraps(func)
        def wrapper(*args: Any, **kwargs: Any) -> Any:
            bound = sig.bind(*args, **kwargs)
            # After apply_defaults every declared parameter is present in
            # bound.arguments, so each spec's name can be looked up directly.
            bound.apply_defaults()

            for param_name, spec in specs.items():
                bound.arguments[param_name] = spec.validate(
                    bound.arguments[param_name], param_name
                )

            return func(*bound.args, **bound.kwargs)

        return wrapper

    return decorator
|
|
758
|
+
|
|
759
|
+
|
|
760
|
+
def check_compatible_shapes(
|
|
761
|
+
*shapes: tuple[int, ...],
|
|
762
|
+
names: Sequence[str] | None = None,
|
|
763
|
+
dimension: int | None = None,
|
|
764
|
+
) -> None:
|
|
765
|
+
"""
|
|
766
|
+
Check that array shapes are compatible for operations.
|
|
767
|
+
|
|
768
|
+
Parameters
|
|
769
|
+
----------
|
|
770
|
+
*shapes : tuple of int
|
|
771
|
+
Shapes to check for compatibility.
|
|
772
|
+
names : sequence of str, optional
|
|
773
|
+
Names for error messages.
|
|
774
|
+
dimension : int, optional
|
|
775
|
+
If provided, only check compatibility along this dimension.
|
|
776
|
+
|
|
777
|
+
Raises
|
|
778
|
+
------
|
|
779
|
+
ValidationError
|
|
780
|
+
If shapes are not compatible.
|
|
781
|
+
|
|
782
|
+
Examples
|
|
783
|
+
--------
|
|
784
|
+
>>> check_compatible_shapes((3, 4), (4, 5), names=["A", "B"], dimension=0)
|
|
785
|
+
# Raises: A has 3 rows but B has 4 rows
|
|
786
|
+
|
|
787
|
+
>>> check_compatible_shapes((3, 4), (4, 5), names=["A", "B"])
|
|
788
|
+
# Passes (inner dimensions compatible for matrix multiply)
|
|
789
|
+
"""
|
|
790
|
+
if len(shapes) < 2:
|
|
791
|
+
return
|
|
792
|
+
|
|
793
|
+
if names is None:
|
|
794
|
+
names = [f"array_{i}" for i in range(len(shapes))]
|
|
795
|
+
|
|
796
|
+
if dimension is not None:
|
|
797
|
+
# Check specific dimension
|
|
798
|
+
dims = [s[dimension] if len(s) > dimension else None for s in shapes]
|
|
799
|
+
valid_dims = [d for d in dims if d is not None]
|
|
800
|
+
if valid_dims and not all(d == valid_dims[0] for d in valid_dims):
|
|
801
|
+
dim_strs = [f"{n}={d}" for n, d in zip(names, dims) if d is not None]
|
|
802
|
+
raise ValidationError(
|
|
803
|
+
f"Arrays have incompatible sizes along dimension {dimension}: "
|
|
804
|
+
f"{', '.join(dim_strs)}"
|
|
805
|
+
)
|
pytcl/gravity/egm.py
CHANGED
|
@@ -21,6 +21,7 @@ References
|
|
|
21
21
|
https://earth-info.nga.mil/
|
|
22
22
|
"""
|
|
23
23
|
|
|
24
|
+
import logging
|
|
24
25
|
import os
|
|
25
26
|
from functools import lru_cache
|
|
26
27
|
from pathlib import Path
|
|
@@ -32,6 +33,9 @@ from numpy.typing import NDArray
|
|
|
32
33
|
from .clenshaw import clenshaw_gravity, clenshaw_potential
|
|
33
34
|
from .models import WGS84, normal_gravity_somigliana
|
|
34
35
|
|
|
36
|
+
# Module logger, named explicitly after this module's dotted path; used for the
# coefficient-loading debug/info messages emitted by _load_coefficients_cached.
_logger = logging.getLogger("pytcl.gravity.egm")
|
|
38
|
+
|
|
35
39
|
|
|
36
40
|
class EGMCoefficients(NamedTuple):
|
|
37
41
|
"""Earth Gravitational Model coefficients.
|
|
@@ -317,6 +321,8 @@ def _load_coefficients_cached(
|
|
|
317
321
|
data_dir = get_data_dir()
|
|
318
322
|
filepath = data_dir / f"{model}.cof"
|
|
319
323
|
|
|
324
|
+
_logger.debug("Loading %s coefficients from %s", model, filepath)
|
|
325
|
+
|
|
320
326
|
if not filepath.exists():
|
|
321
327
|
raise FileNotFoundError(
|
|
322
328
|
f"Coefficient file not found: {filepath}\n"
|
|
@@ -330,6 +336,13 @@ def _load_coefficients_cached(
|
|
|
330
336
|
actual_n_max = n_max if n_max is not None else int(params["n_max_full"])
|
|
331
337
|
C, S = parse_egm_file(filepath, actual_n_max)
|
|
332
338
|
|
|
339
|
+
_logger.info(
|
|
340
|
+
"Loaded %s coefficients: n_max=%d, array_size=%.1f MB",
|
|
341
|
+
model,
|
|
342
|
+
C.shape[0] - 1,
|
|
343
|
+
C.nbytes / 1024 / 1024 * 2, # Both C and S arrays
|
|
344
|
+
)
|
|
345
|
+
|
|
333
346
|
return EGMCoefficients(
|
|
334
347
|
C=C,
|
|
335
348
|
S=S,
|
|
@@ -11,11 +11,69 @@ References
|
|
|
11
11
|
.. [2] O. Montenbruck and E. Gill, "Satellite Orbits," Springer, 2000.
|
|
12
12
|
"""
|
|
13
13
|
|
|
14
|
+
import logging
|
|
15
|
+
from functools import lru_cache
|
|
14
16
|
from typing import Optional, Tuple
|
|
15
17
|
|
|
16
18
|
import numpy as np
|
|
17
19
|
from numpy.typing import NDArray
|
|
18
20
|
|
|
21
|
+
# Module logger
_logger = logging.getLogger("pytcl.gravity.spherical_harmonics")

# Cache configuration for Legendre polynomials.
# x is rounded to this many decimal places before it is used as a cache key,
# so arguments that agree to ~1e-8 share one cached table. Note that the table
# is then evaluated at the rounded x, not at the exact input.
_LEGENDRE_CACHE_DECIMALS = 8  # Precision for x quantization
# Each cache entry holds an (n_max + 1) x (m_max + 1) table, so memory use grows
# with the square of the degree.
# NOTE(review): confirm this size is acceptable for high-degree gravity models.
_LEGENDRE_CACHE_MAXSIZE = 64  # Max cached (n_max, m_max, x, normalized) combinations
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
def _quantize_x(x: float) -> float:
    """Round ``x`` to the configured cache precision for use as a cache key.

    Arguments that agree to ``_LEGENDRE_CACHE_DECIMALS`` decimal places map to
    the same key and therefore share one cached Legendre table.
    """
    quantized = round(x, _LEGENDRE_CACHE_DECIMALS)
    return quantized
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
@lru_cache(maxsize=_LEGENDRE_CACHE_MAXSIZE)
def _associated_legendre_cached(
    n_max: int,
    m_max: int,
    x_quantized: float,
    normalized: bool,
) -> NDArray[np.float64]:
    """Compute the associated Legendre table P[n, m] at one (quantized) x.

    Internal helper wrapped in ``lru_cache``. Only the *arguments* of an
    ``lru_cache`` function must be hashable, so the table is cached as a
    float64 ndarray. Caching nested tuples of Python floats would use several
    times the memory and force an O(n_max * m_max) object conversion on every
    cache hit. The array is marked read-only so no caller can corrupt the
    shared cached entry; use ``np.array(result)`` to get a writable copy.

    Parameters
    ----------
    n_max, m_max : int
        Maximum degree and order; the table has shape (n_max + 1, m_max + 1).
    x_quantized : float
        Argument in [-1, 1], already rounded for cache-key stability.
    normalized : bool
        If True, use the normalized recursion coefficients; otherwise the
        unnormalized ones.

    Returns
    -------
    ndarray
        Read-only float64 array of shape (n_max + 1, m_max + 1).
    """
    P = np.zeros((n_max + 1, m_max + 1))
    u = np.sqrt(1 - x_quantized * x_quantized)

    P[0, 0] = 1.0

    # Sectoral recursion: P_m^m from P_{m-1}^{m-1}
    for m in range(1, m_max + 1):
        if normalized:
            P[m, m] = u * np.sqrt((2 * m + 1) / (2 * m)) * P[m - 1, m - 1]
        else:
            P[m, m] = (2 * m - 1) * u * P[m - 1, m - 1]

    # First off-diagonal: P_{m+1}^m from P_m^m
    for m in range(m_max):
        if m + 1 <= n_max:
            if normalized:
                P[m + 1, m] = x_quantized * np.sqrt(2 * m + 3) * P[m, m]
            else:
                P[m + 1, m] = x_quantized * (2 * m + 1) * P[m, m]

    # General recursion: P_n^m from P_{n-1}^m and P_{n-2}^m
    for m in range(m_max + 1):
        for n in range(m + 2, n_max + 1):
            if normalized:
                a_nm = np.sqrt((4 * n * n - 1) / (n * n - m * m))
                b_nm = np.sqrt(((n - 1) ** 2 - m * m) / (4 * (n - 1) ** 2 - 1))
                P[n, m] = a_nm * (x_quantized * P[n - 1, m] - b_nm * P[n - 2, m])
            else:
                P[n, m] = (
                    (2 * n - 1) * x_quantized * P[n - 1, m] - (n + m - 1) * P[n - 2, m]
                ) / (n - m)

    # Freeze the cached table so callers cannot mutate the shared entry.
    P.setflags(write=False)
    return P
|
|
76
|
+
|
|
19
77
|
|
|
20
78
|
def associated_legendre(
|
|
21
79
|
n_max: int,
|
|
@@ -53,6 +111,9 @@ def associated_legendre(
|
|
|
53
111
|
|
|
54
112
|
\\int_{-1}^{1} [\\bar{P}_n^m(x)]^2 dx = \\frac{2}{2n+1}
|
|
55
113
|
|
|
114
|
+
Results are cached for repeated queries with the same parameters.
|
|
115
|
+
Cache key quantizes x to 8 decimal places (~1e-8 precision).
|
|
116
|
+
|
|
56
117
|
Examples
|
|
57
118
|
--------
|
|
58
119
|
>>> P = associated_legendre(2, 2, 0.5)
|
|
@@ -63,42 +124,10 @@ def associated_legendre(
|
|
|
63
124
|
if not -1 <= x <= 1:
|
|
64
125
|
raise ValueError("x must be in [-1, 1]")
|
|
65
126
|
|
|
66
|
-
|
|
67
|
-
|
|
68
|
-
|
|
69
|
-
|
|
70
|
-
|
|
71
|
-
# Seed values
|
|
72
|
-
P[0, 0] = 1.0
|
|
73
|
-
|
|
74
|
-
# Sectoral recursion: P_m^m from P_{m-1}^{m-1}
|
|
75
|
-
for m in range(1, m_max + 1):
|
|
76
|
-
if normalized:
|
|
77
|
-
P[m, m] = u * np.sqrt((2 * m + 1) / (2 * m)) * P[m - 1, m - 1]
|
|
78
|
-
else:
|
|
79
|
-
P[m, m] = (2 * m - 1) * u * P[m - 1, m - 1]
|
|
80
|
-
|
|
81
|
-
# Compute P_{m+1}^m from P_m^m
|
|
82
|
-
for m in range(m_max):
|
|
83
|
-
if m + 1 <= n_max:
|
|
84
|
-
if normalized:
|
|
85
|
-
P[m + 1, m] = x * np.sqrt(2 * m + 3) * P[m, m]
|
|
86
|
-
else:
|
|
87
|
-
P[m + 1, m] = x * (2 * m + 1) * P[m, m]
|
|
88
|
-
|
|
89
|
-
# General recursion: P_n^m from P_{n-1}^m and P_{n-2}^m
|
|
90
|
-
for m in range(m_max + 1):
|
|
91
|
-
for n in range(m + 2, n_max + 1):
|
|
92
|
-
if normalized:
|
|
93
|
-
a_nm = np.sqrt((4 * n * n - 1) / (n * n - m * m))
|
|
94
|
-
b_nm = np.sqrt(((n - 1) ** 2 - m * m) / (4 * (n - 1) ** 2 - 1))
|
|
95
|
-
P[n, m] = a_nm * (x * P[n - 1, m] - b_nm * P[n - 2, m])
|
|
96
|
-
else:
|
|
97
|
-
P[n, m] = (
|
|
98
|
-
(2 * n - 1) * x * P[n - 1, m] - (n + m - 1) * P[n - 2, m]
|
|
99
|
-
) / (n - m)
|
|
100
|
-
|
|
101
|
-
return P
|
|
127
|
+
# Use cached computation
|
|
128
|
+
x_q = _quantize_x(x)
|
|
129
|
+
cached = _associated_legendre_cached(n_max, m_max, x_q, normalized)
|
|
130
|
+
return np.array(cached)
|
|
102
131
|
|
|
103
132
|
|
|
104
133
|
def associated_legendre_derivative(
|
|
@@ -230,6 +259,14 @@ def spherical_harmonic_sum(
|
|
|
230
259
|
if n_max is None:
|
|
231
260
|
n_max = C.shape[0] - 1
|
|
232
261
|
|
|
262
|
+
_logger.debug(
|
|
263
|
+
"spherical_harmonic_sum: lat=%.4f, lon=%.4f, r=%.1f, n_max=%d",
|
|
264
|
+
lat,
|
|
265
|
+
lon,
|
|
266
|
+
r,
|
|
267
|
+
n_max,
|
|
268
|
+
)
|
|
269
|
+
|
|
233
270
|
# Colatitude for Legendre polynomials
|
|
234
271
|
colat = np.pi / 2 - lat
|
|
235
272
|
cos_colat = np.cos(colat)
|
|
@@ -495,6 +532,28 @@ def associated_legendre_scaled(
|
|
|
495
532
|
return P_scaled, scale_exp
|
|
496
533
|
|
|
497
534
|
|
|
535
|
+
def clear_legendre_cache() -> None:
    """Empty the associated Legendre polynomial cache.

    Drops every table memoized by ``_associated_legendre_cached`` and resets
    its hit/miss statistics. Useful for releasing memory, for example after a
    batch of evaluations at colatitudes that will not recur.
    """
    cached_fn = _associated_legendre_cached
    cached_fn.cache_clear()
    _logger.debug("Legendre polynomial cache cleared")
|
|
544
|
+
|
|
545
|
+
|
|
546
|
+
def get_legendre_cache_info():
    """Report statistics for the associated Legendre polynomial cache.

    Returns
    -------
    CacheInfo
        The ``functools`` named tuple ``(hits, misses, maxsize, currsize)``
        describing the cache wrapped around ``_associated_legendre_cached``.
    """
    info = _associated_legendre_cached.cache_info()
    return info
|
|
555
|
+
|
|
556
|
+
|
|
498
557
|
__all__ = [
|
|
499
558
|
"associated_legendre",
|
|
500
559
|
"associated_legendre_derivative",
|
|
@@ -502,4 +561,6 @@ __all__ = [
|
|
|
502
561
|
"gravity_acceleration",
|
|
503
562
|
"legendre_scaling_factors",
|
|
504
563
|
"associated_legendre_scaled",
|
|
564
|
+
"clear_legendre_cache",
|
|
565
|
+
"get_legendre_cache_info",
|
|
505
566
|
]
|
|
@@ -3,13 +3,84 @@ Hypergeometric functions.
|
|
|
3
3
|
|
|
4
4
|
This module provides hypergeometric functions commonly used in
|
|
5
5
|
mathematical physics, probability theory, and special function evaluation.
|
|
6
|
+
|
|
7
|
+
Performance
|
|
8
|
+
-----------
|
|
9
|
+
The generalized hypergeometric function uses Numba JIT compilation for
|
|
10
|
+
the series summation loop, providing significant speedup for the general
|
|
11
|
+
case (p > 2 or q > 1).
|
|
6
12
|
"""
|
|
7
13
|
|
|
8
14
|
import numpy as np
|
|
9
15
|
import scipy.special as sp
|
|
16
|
+
from numba import njit
|
|
10
17
|
from numpy.typing import ArrayLike, NDArray
|
|
11
18
|
|
|
12
19
|
|
|
20
|
+
@njit(cache=True, fastmath=True)
def _hypergeometric_series(
    a: np.ndarray,
    b: np.ndarray,
    z: np.ndarray,
    max_terms: int,
    tol: float,
) -> np.ndarray:
    """
    Sum the generalized hypergeometric (pFq) power series term by term.

    Consecutive terms satisfy ``t_k = t_{k-1} * z * ratio_k`` where
    ``ratio_k = prod(a_i + k - 1) / (k * prod(b_j + k - 1))``. The ratio does
    not depend on z, so it is computed once per k and shared by every z.

    Parameters
    ----------
    a : ndarray
        Numerator parameters (1D array).
    b : ndarray
        Denominator parameters (1D array).
    z : ndarray
        Argument values (1D array).
    max_terms : int
        Maximum number of series terms (the leading 1 counts as the first).
    tol : float
        Relative convergence tolerance.

    Returns
    -------
    result : ndarray
        Partial sums, one per z. Summation stops as soon as no z value has
        ``|term| >= tol * |sum|``, or after ``max_terms - 1`` added terms;
        no error is raised if the tolerance was never reached.
    """
    count = len(z)
    n_num = len(a)
    n_den = len(b)

    sums = np.ones(count, dtype=np.float64)
    terms = np.ones(count, dtype=np.float64)

    for k in range(1, max_terms):
        # Rising-factorial step for the numerator: prod(a_i + k - 1)
        numer = 1.0
        for idx in range(n_num):
            numer *= a[idx] + k - 1

        # Denominator step: k * prod(b_j + k - 1)
        denom = float(k)
        for idx in range(n_den):
            denom *= b[idx] + k - 1

        ratio = numer / denom

        done = True
        for j in range(count):
            terms[j] = terms[j] * z[j] * ratio
            sums[j] += terms[j]
            # Any z whose latest term is still relatively large keeps the
            # loop going.
            if np.abs(terms[j]) >= tol * np.abs(sums[j]):
                done = False

        if done:
            break

    return sums
|
|
82
|
+
|
|
83
|
+
|
|
13
84
|
def hyp0f1(
|
|
14
85
|
b: ArrayLike,
|
|
15
86
|
z: ArrayLike,
|
|
@@ -369,6 +440,11 @@ def generalized_hypergeometric(
|
|
|
369
440
|
- p = q + 1: |z| < 1
|
|
370
441
|
- p > q + 1: diverges except for polynomial cases
|
|
371
442
|
|
|
443
|
+
Performance
|
|
444
|
+
-----------
|
|
445
|
+
Uses Numba JIT compilation for the general case (p > 2 or q > 1),
|
|
446
|
+
providing 5-10x speedup over pure Python loops.
|
|
447
|
+
|
|
372
448
|
Examples
|
|
373
449
|
--------
|
|
374
450
|
>>> generalized_hypergeometric([1], [2], 1) # 1F1(1; 2; 1) ~ 1.718...
|
|
@@ -389,21 +465,9 @@ def generalized_hypergeometric(
|
|
|
389
465
|
elif p == 2 and q == 1:
|
|
390
466
|
return hyp2f1(a[0], a[1], b[0], z)
|
|
391
467
|
|
|
392
|
-
# General case: series summation
|
|
393
|
-
|
|
394
|
-
result =
|
|
395
|
-
term = np.ones_like(z, dtype=np.float64)
|
|
396
|
-
|
|
397
|
-
for k in range(1, max_terms):
|
|
398
|
-
# Compute ratio term_k / term_{k-1}
|
|
399
|
-
num_factor = np.prod(a + k - 1)
|
|
400
|
-
den_factor = np.prod(b + k - 1) * k
|
|
401
|
-
term = term * z * num_factor / den_factor
|
|
402
|
-
|
|
403
|
-
result += term
|
|
404
|
-
|
|
405
|
-
if np.all(np.abs(term) < tol * np.abs(result)):
|
|
406
|
-
break
|
|
468
|
+
# General case: use Numba-optimized series summation
|
|
469
|
+
z_arr = np.atleast_1d(z)
|
|
470
|
+
result = _hypergeometric_series(a, b, z_arr, max_terms, tol)
|
|
407
471
|
|
|
408
472
|
return result if result.size > 1 else result[0]
|
|
409
473
|
|