nrl-tracker 1.1.3__py3-none-any.whl → 1.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
pytcl/core/validation.py CHANGED
@@ -472,3 +472,334 @@ def validated_array_input(
472
472
  return wrapper
473
473
 
474
474
  return decorator
475
+
476
+
477
+ class ArraySpec:
478
+ """
479
+ Specification for array validation in @validate_inputs decorator.
480
+
481
+ Parameters
482
+ ----------
483
+ dtype : type or np.dtype, optional
484
+ Required dtype.
485
+ ndim : int or tuple of int, optional
486
+ Required dimensionality.
487
+ shape : tuple, optional
488
+ Required shape (None for any size).
489
+ min_ndim : int, optional
490
+ Minimum dimensions required.
491
+ max_ndim : int, optional
492
+ Maximum dimensions allowed.
493
+ finite : bool, optional
494
+ Require all finite values.
495
+ non_negative : bool, optional
496
+ Require all values >= 0.
497
+ positive : bool, optional
498
+ Require all values > 0.
499
+ allow_empty : bool, optional
500
+ Allow empty arrays. Default True.
501
+ square : bool, optional
502
+ Require square matrix.
503
+ symmetric : bool, optional
504
+ Require symmetric matrix.
505
+ positive_definite : bool, optional
506
+ Require positive definite matrix.
507
+
508
+ Examples
509
+ --------
510
+ >>> spec = ArraySpec(ndim=2, finite=True, square=True)
511
+ >>> @validate_inputs(matrix=spec)
512
+ ... def process_matrix(matrix):
513
+ ... return np.linalg.inv(matrix)
514
+ """
515
+
516
+ def __init__(
517
+ self,
518
+ *,
519
+ dtype: type | np.dtype | None = None,
520
+ ndim: int | tuple[int, ...] | None = None,
521
+ shape: tuple[int | None, ...] | None = None,
522
+ min_ndim: int | None = None,
523
+ max_ndim: int | None = None,
524
+ finite: bool = False,
525
+ non_negative: bool = False,
526
+ positive: bool = False,
527
+ allow_empty: bool = True,
528
+ square: bool = False,
529
+ symmetric: bool = False,
530
+ positive_definite: bool = False,
531
+ ):
532
+ self.dtype = dtype
533
+ self.ndim = ndim
534
+ self.shape = shape
535
+ self.min_ndim = min_ndim
536
+ self.max_ndim = max_ndim
537
+ self.finite = finite
538
+ self.non_negative = non_negative
539
+ self.positive = positive
540
+ self.allow_empty = allow_empty
541
+ self.square = square
542
+ self.symmetric = symmetric
543
+ self.positive_definite = positive_definite
544
+
545
+ def validate(self, arr: ArrayLike, name: str) -> NDArray[Any]:
546
+ """Validate an array against this specification."""
547
+ result = validate_array(
548
+ arr,
549
+ name,
550
+ dtype=self.dtype,
551
+ ndim=self.ndim,
552
+ shape=self.shape,
553
+ min_ndim=self.min_ndim,
554
+ max_ndim=self.max_ndim,
555
+ finite=self.finite,
556
+ non_negative=self.non_negative,
557
+ positive=self.positive,
558
+ allow_empty=self.allow_empty,
559
+ )
560
+
561
+ if self.positive_definite:
562
+ result = ensure_positive_definite(result, name)
563
+ elif self.symmetric:
564
+ result = ensure_symmetric(result, name)
565
+ elif self.square:
566
+ result = ensure_square_matrix(result, name)
567
+
568
+ return result
569
+
570
+
571
+ class ScalarSpec:
572
+ """
573
+ Specification for scalar validation in @validate_inputs decorator.
574
+
575
+ Parameters
576
+ ----------
577
+ dtype : type, optional
578
+ Required type (int, float, etc.).
579
+ min_value : float, optional
580
+ Minimum allowed value (inclusive).
581
+ max_value : float, optional
582
+ Maximum allowed value (inclusive).
583
+ finite : bool, optional
584
+ Require finite value.
585
+ positive : bool, optional
586
+ Require value > 0.
587
+ non_negative : bool, optional
588
+ Require value >= 0.
589
+
590
+ Examples
591
+ --------
592
+ >>> spec = ScalarSpec(dtype=int, min_value=1, max_value=10)
593
+ >>> @validate_inputs(k=spec)
594
+ ... def get_k_nearest(k, data):
595
+ ... return data[:k]
596
+ """
597
+
598
+ def __init__(
599
+ self,
600
+ *,
601
+ dtype: type | None = None,
602
+ min_value: float | None = None,
603
+ max_value: float | None = None,
604
+ finite: bool = False,
605
+ positive: bool = False,
606
+ non_negative: bool = False,
607
+ ):
608
+ self.dtype = dtype
609
+ self.min_value = min_value
610
+ self.max_value = max_value
611
+ self.finite = finite
612
+ self.positive = positive
613
+ self.non_negative = non_negative
614
+
615
+ def validate(self, value: Any, name: str) -> Any:
616
+ """Validate a scalar value against this specification."""
617
+ # Type check
618
+ if self.dtype is not None:
619
+ if not isinstance(value, self.dtype):
620
+ try:
621
+ value = self.dtype(value)
622
+ except (ValueError, TypeError) as e:
623
+ raise ValidationError(
624
+ f"{name} must be {self.dtype.__name__}, got {type(value).__name__}"
625
+ ) from e
626
+
627
+ # Convert to float for numeric checks
628
+ try:
629
+ num_value = float(value)
630
+ except (ValueError, TypeError):
631
+ if any(
632
+ [
633
+ self.finite,
634
+ self.positive,
635
+ self.non_negative,
636
+ self.min_value is not None,
637
+ self.max_value is not None,
638
+ ]
639
+ ):
640
+ raise ValidationError(
641
+ f"{name} must be numeric for range validation"
642
+ ) from None
643
+ return value
644
+
645
+ # Finite check
646
+ if self.finite and not np.isfinite(num_value):
647
+ raise ValidationError(f"{name} must be finite, got {value}")
648
+
649
+ # Positive check
650
+ if self.positive and num_value <= 0:
651
+ raise ValidationError(f"{name} must be positive, got {value}")
652
+
653
+ # Non-negative check
654
+ if self.non_negative and num_value < 0:
655
+ raise ValidationError(f"{name} must be non-negative, got {value}")
656
+
657
+ # Range checks
658
+ if self.min_value is not None and num_value < self.min_value:
659
+ raise ValidationError(f"{name} must be >= {self.min_value}, got {value}")
660
+
661
+ if self.max_value is not None and num_value > self.max_value:
662
+ raise ValidationError(f"{name} must be <= {self.max_value}, got {value}")
663
+
664
+ return value
665
+
666
+
667
def validate_inputs(
    **param_specs: ArraySpec | ScalarSpec | dict[str, Any],
) -> Callable[[F], F]:
    """
    Decorator for validating multiple function parameters.

    This decorator enables declarative input validation using specification
    objects (ArraySpec, ScalarSpec) or dictionaries of validation options.

    Parameters
    ----------
    **param_specs : ArraySpec | ScalarSpec | dict
        Keyword arguments mapping parameter names to validation specs.
        Each spec can be:
        - ArraySpec: For array validation
        - ScalarSpec: For scalar validation
        - dict: Options passed to ArraySpec (for convenience)

    Returns
    -------
    Callable
        Decorated function with input validation.

    Raises
    ------
    TypeError
        When ``validate_inputs(...)`` is evaluated, if a spec is not an
        ArraySpec, ScalarSpec, or dict of ArraySpec options. When a function
        is decorated, if a spec names a parameter the function does not have
        and the function does not accept ``**kwargs``.

    Examples
    --------
    >>> @validate_inputs(
    ...     x=ArraySpec(ndim=2, finite=True),
    ...     P=ArraySpec(ndim=2, positive_definite=True),
    ...     k=ScalarSpec(dtype=int, min_value=1),
    ... )
    ... def kalman_update(x, P, z, H, R, k=1):
    ...     # x and P are guaranteed valid here
    ...     pass

    Using dict shorthand:

    >>> @validate_inputs(
    ...     state={"ndim": 1, "finite": True},
    ...     covariance={"ndim": 2, "positive_definite": True},
    ... )
    ... def predict(state, covariance, dt):
    ...     pass

    Notes
    -----
    Validation happens in the order parameters are defined in the decorator.
    If any validation fails, a ValidationError is raised with a descriptive
    message identifying the parameter and the constraint violated.

    Parameter defaults are filled in before validation, so defaulted
    arguments are validated too. For functions that accept ``**kwargs``,
    specs for names that are not explicit parameters are applied to the
    matching keyword argument when it is supplied.

    See Also
    --------
    ArraySpec : Specification class for array validation.
    ScalarSpec : Specification class for scalar validation.
    validate_array : Lower-level array validation function.
    """
    # Normalize once, up front: dict shorthands become ArraySpec objects and
    # bad spec types fail immediately rather than on the first call.
    specs: dict[str, ArraySpec | ScalarSpec] = {}
    for param_name, spec in param_specs.items():
        if isinstance(spec, dict):
            spec = ArraySpec(**spec)
        if not isinstance(spec, (ArraySpec, ScalarSpec)):
            raise TypeError(
                f"Invalid spec type for {param_name}: {type(spec)}. "
                "Use ArraySpec, ScalarSpec, or dict."
            )
        specs[param_name] = spec

    def decorator(func: F) -> F:
        import inspect

        # Pre-fetch signature for efficiency
        sig = inspect.signature(func)
        params = sig.parameters
        # Name of the function's ``**kwargs`` parameter, if it has one.
        var_keyword = next(
            (
                p.name
                for p in params.values()
                if p.kind is inspect.Parameter.VAR_KEYWORD
            ),
            None,
        )

        # A spec for a name the function cannot receive would be silently
        # ignored (typically a typo), quietly disabling that validation.
        unknown = [n for n in specs if n not in params]
        if unknown and var_keyword is None:
            func_name = getattr(func, "__qualname__", repr(func))
            raise TypeError(
                f"validate_inputs: {func_name}() has no parameter(s) "
                + ", ".join(repr(n) for n in unknown)
            )

        @wraps(func)
        def wrapper(*args: Any, **kwargs: Any) -> Any:
            bound = sig.bind(*args, **kwargs)
            bound.apply_defaults()
            arguments = bound.arguments
            # Dict collected by **kwargs (if any); mutating it in place
            # updates what bound.kwargs forwards to func.
            extra = arguments.get(var_keyword, {}) if var_keyword is not None else {}

            for param_name, spec in specs.items():
                if param_name in arguments:
                    arguments[param_name] = spec.validate(
                        arguments[param_name], param_name
                    )
                elif param_name in extra:
                    extra[param_name] = spec.validate(extra[param_name], param_name)

            return func(*bound.args, **bound.kwargs)

        return wrapper

    return decorator
758
+
759
+
760
+ def check_compatible_shapes(
761
+ *shapes: tuple[int, ...],
762
+ names: Sequence[str] | None = None,
763
+ dimension: int | None = None,
764
+ ) -> None:
765
+ """
766
+ Check that array shapes are compatible for operations.
767
+
768
+ Parameters
769
+ ----------
770
+ *shapes : tuple of int
771
+ Shapes to check for compatibility.
772
+ names : sequence of str, optional
773
+ Names for error messages.
774
+ dimension : int, optional
775
+ If provided, only check compatibility along this dimension.
776
+
777
+ Raises
778
+ ------
779
+ ValidationError
780
+ If shapes are not compatible.
781
+
782
+ Examples
783
+ --------
784
+ >>> check_compatible_shapes((3, 4), (4, 5), names=["A", "B"], dimension=0)
785
+ # Raises: A has 3 rows but B has 4 rows
786
+
787
+ >>> check_compatible_shapes((3, 4), (4, 5), names=["A", "B"])
788
+ # Passes (inner dimensions compatible for matrix multiply)
789
+ """
790
+ if len(shapes) < 2:
791
+ return
792
+
793
+ if names is None:
794
+ names = [f"array_{i}" for i in range(len(shapes))]
795
+
796
+ if dimension is not None:
797
+ # Check specific dimension
798
+ dims = [s[dimension] if len(s) > dimension else None for s in shapes]
799
+ valid_dims = [d for d in dims if d is not None]
800
+ if valid_dims and not all(d == valid_dims[0] for d in valid_dims):
801
+ dim_strs = [f"{n}={d}" for n, d in zip(names, dims) if d is not None]
802
+ raise ValidationError(
803
+ f"Arrays have incompatible sizes along dimension {dimension}: "
804
+ f"{', '.join(dim_strs)}"
805
+ )