openscvx 0.3.2.dev170__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of openscvx might be problematic. Click here for more details.
- openscvx/__init__.py +123 -0
- openscvx/_version.py +34 -0
- openscvx/algorithms/__init__.py +92 -0
- openscvx/algorithms/autotuning.py +24 -0
- openscvx/algorithms/base.py +351 -0
- openscvx/algorithms/optimization_results.py +215 -0
- openscvx/algorithms/penalized_trust_region.py +384 -0
- openscvx/config.py +437 -0
- openscvx/discretization/__init__.py +47 -0
- openscvx/discretization/discretization.py +236 -0
- openscvx/expert/__init__.py +23 -0
- openscvx/expert/byof.py +326 -0
- openscvx/expert/lowering.py +419 -0
- openscvx/expert/validation.py +357 -0
- openscvx/integrators/__init__.py +48 -0
- openscvx/integrators/runge_kutta.py +281 -0
- openscvx/lowered/__init__.py +30 -0
- openscvx/lowered/cvxpy_constraints.py +23 -0
- openscvx/lowered/cvxpy_variables.py +124 -0
- openscvx/lowered/dynamics.py +34 -0
- openscvx/lowered/jax_constraints.py +133 -0
- openscvx/lowered/parameters.py +54 -0
- openscvx/lowered/problem.py +70 -0
- openscvx/lowered/unified.py +718 -0
- openscvx/plotting/__init__.py +63 -0
- openscvx/plotting/plotting.py +756 -0
- openscvx/plotting/scp_iteration.py +299 -0
- openscvx/plotting/viser/__init__.py +126 -0
- openscvx/plotting/viser/animated.py +605 -0
- openscvx/plotting/viser/plotly_integration.py +333 -0
- openscvx/plotting/viser/primitives.py +355 -0
- openscvx/plotting/viser/scp.py +459 -0
- openscvx/plotting/viser/server.py +112 -0
- openscvx/problem.py +734 -0
- openscvx/propagation/__init__.py +60 -0
- openscvx/propagation/post_processing.py +104 -0
- openscvx/propagation/propagation.py +248 -0
- openscvx/solvers/__init__.py +51 -0
- openscvx/solvers/cvxpy.py +226 -0
- openscvx/symbolic/__init__.py +9 -0
- openscvx/symbolic/augmentation.py +630 -0
- openscvx/symbolic/builder.py +492 -0
- openscvx/symbolic/constraint_set.py +92 -0
- openscvx/symbolic/expr/__init__.py +222 -0
- openscvx/symbolic/expr/arithmetic.py +517 -0
- openscvx/symbolic/expr/array.py +632 -0
- openscvx/symbolic/expr/constraint.py +796 -0
- openscvx/symbolic/expr/control.py +135 -0
- openscvx/symbolic/expr/expr.py +720 -0
- openscvx/symbolic/expr/lie/__init__.py +87 -0
- openscvx/symbolic/expr/lie/adjoint.py +357 -0
- openscvx/symbolic/expr/lie/se3.py +172 -0
- openscvx/symbolic/expr/lie/so3.py +138 -0
- openscvx/symbolic/expr/linalg.py +279 -0
- openscvx/symbolic/expr/math.py +699 -0
- openscvx/symbolic/expr/spatial.py +209 -0
- openscvx/symbolic/expr/state.py +607 -0
- openscvx/symbolic/expr/stl.py +136 -0
- openscvx/symbolic/expr/variable.py +321 -0
- openscvx/symbolic/hashing.py +112 -0
- openscvx/symbolic/lower.py +760 -0
- openscvx/symbolic/lowerers/__init__.py +106 -0
- openscvx/symbolic/lowerers/cvxpy.py +1302 -0
- openscvx/symbolic/lowerers/jax.py +1382 -0
- openscvx/symbolic/preprocessing.py +757 -0
- openscvx/symbolic/problem.py +110 -0
- openscvx/symbolic/time.py +116 -0
- openscvx/symbolic/unified.py +420 -0
- openscvx/utils/__init__.py +20 -0
- openscvx/utils/cache.py +131 -0
- openscvx/utils/caching.py +210 -0
- openscvx/utils/printing.py +301 -0
- openscvx/utils/profiling.py +37 -0
- openscvx/utils/utils.py +100 -0
- openscvx-0.3.2.dev170.dist-info/METADATA +350 -0
- openscvx-0.3.2.dev170.dist-info/RECORD +79 -0
- openscvx-0.3.2.dev170.dist-info/WHEEL +5 -0
- openscvx-0.3.2.dev170.dist-info/licenses/LICENSE +201 -0
- openscvx-0.3.2.dev170.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,110 @@
|
|
|
1
|
+
"""SymbolicProblem dataclass - container for symbolic problem specification.
|
|
2
|
+
|
|
3
|
+
This module provides the SymbolicProblem dataclass that represents a trajectory
|
|
4
|
+
optimization problem in symbolic form, before lowering to executable code.
|
|
5
|
+
|
|
6
|
+
The SymbolicProblem can represent two lifecycle stages:
|
|
7
|
+
|
|
8
|
+
1. **Before preprocessing**: Raw user input with unsorted constraints
|
|
9
|
+
2. **After preprocessing**: Augmented and validated, ready for lowering
|
|
10
|
+
|
|
11
|
+
Use `is_preprocessed` to check which stage the problem is in.
|
|
12
|
+
"""
|
|
13
|
+
|
|
14
|
+
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple

from openscvx.symbolic.constraint_set import ConstraintSet

if TYPE_CHECKING:
    from openscvx.symbolic.expr import Expr
    from openscvx.symbolic.expr.control import Control
    from openscvx.symbolic.expr.state import State
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
@dataclass
class SymbolicProblem:
    """Container for symbolic problem specification.

    This dataclass holds a trajectory optimization problem in symbolic form,
    either as raw user input or after preprocessing/augmentation. It provides
    a typed interface for the preprocessing and lowering pipeline.

    Lifecycle Stages:
        1. **Before preprocessing**: User creates with raw dynamics, states,
           controls, and unsorted constraints. Propagation fields are None.
        2. **After preprocessing**: Dynamics and states are augmented (CTCS,
           time dilation), constraints are categorized, propagation fields
           are populated.

    Use `is_preprocessed` to check whether preprocessing has completed.

    Attributes:
        dynamics: Symbolic dynamics expression (dx/dt = f(x, u)).
            After preprocessing, includes CTCS augmented state dynamics.
        states: List of State objects. After preprocessing, includes
            time state and CTCS augmented states.
        controls: List of Control objects. After preprocessing, includes
            time dilation control.
        constraints: ConstraintSet holding all constraints. Before preprocessing,
            raw constraints live in `constraints.unsorted`. After preprocessing,
            constraints are categorized into ctcs, nodal, nodal_convex, etc.
        parameters: Dictionary mapping parameter names to parameter values
            (e.g. numpy arrays, or plain scalars such as ``1.0`` in the
            example below).
        N: Number of discretization nodes.
        node_intervals: List of (start, end) tuples for CTCS constraint intervals.
            Populated during preprocessing when CTCS constraints are sorted.
            Defaults to a fresh empty list per instance.

        dynamics_prop: Propagation dynamics (may include extra states).
            None before preprocessing, populated after.
        states_prop: Propagation states (may include extra states).
            None before preprocessing, populated after.
        controls_prop: Propagation controls (typically same as controls).
            None before preprocessing, populated after.

    Example:
        Before preprocessing::

            problem = SymbolicProblem(
                dynamics=dynamics_expr,
                states=[x, v],
                controls=[u],
                constraints=ConstraintSet(unsorted=[c1, c2, c3]),
                parameters={"mass": 1.0},
                N=50,
            )
            assert not problem.is_preprocessed

        After preprocessing::

            processed = preprocess_symbolic_problem(problem, time=time_config)
            assert processed.is_preprocessed
            assert processed.constraints.is_categorized
            # Now ready for lowering
            lowered = lower_symbolic_problem(processed)
    """

    # Core problem specification
    dynamics: "Expr"
    states: List["State"]
    controls: List["Control"]
    constraints: ConstraintSet
    # typing.Any, not the builtin any() function (which is not a type).
    parameters: Dict[str, Any]
    N: int

    # CTCS node intervals (populated during preprocessing)
    node_intervals: List[Tuple[int, int]] = field(default_factory=list)

    # Propagation (None before preprocessing, populated after)
    dynamics_prop: Optional["Expr"] = None
    states_prop: Optional[List["State"]] = None
    controls_prop: Optional[List["Control"]] = None

    @property
    def is_preprocessed(self) -> bool:
        """True if the problem has been preprocessed and is ready for lowering.

        A problem is considered preprocessed when:
            1. All constraints have been categorized
               (``constraints.is_categorized`` is true), and
            2. Propagation dynamics have been set up
               (``dynamics_prop`` is not None).
        """
        return self.constraints.is_categorized and self.dynamics_prop is not None
|
|
@@ -0,0 +1,116 @@
|
|
|
1
|
+
from typing import Union
|
|
2
|
+
|
|
3
|
+
|
|
4
|
+
class Time:
    """Time configuration for trajectory optimization problems.

    This class encapsulates time-related parameters for trajectory optimization.
    The time derivative is internally assumed to be 1.0.

    Attributes:
        initial (float or tuple): Initial time boundary condition.
            Can be a float (fixed) or tuple like ("free", value), ("minimize", value),
            or ("maximize", value).
        final (float or tuple): Final time boundary condition.
            Can be a float (fixed) or tuple like ("free", value), ("minimize", value),
            or ("maximize", value).
        min (float): Minimum bound for time variable (required).
        max (float): Maximum bound for time variable (required).
        derivative (float): Time derivative; always 1.0.
        scaling_min / scaling_max: Optional scaling bounds for the time
            variable; None until assigned (values are stored as float).

    Example:
        ```python
        # Fixed initial and final time
        time = Time(initial=0.0, final=10.0, min=0.0, max=20.0)

        # Free final time
        time = Time(initial=0.0, final=("free", 10.0), min=0.0, max=20.0)

        # Minimize final time
        time = Time(initial=0.0, final=("minimize", 10.0), min=0.0, max=20.0)

        # Maximize initial time
        time = Time(initial=("maximize", 0.0), final=10.0, min=0.0, max=20.0)
        ```
    """

    # Boundary-condition kinds accepted in the (type, value) tuple form.
    _BC_TYPES = ("free", "minimize", "maximize")

    def __init__(
        self,
        initial: Union[float, tuple],
        final: Union[float, tuple],
        min: float,
        max: float,
    ):
        """Initialize a Time object.

        Args:
            initial: Initial time boundary condition (float or tuple).
                Tuple format: ("free", value), ("minimize", value), or ("maximize", value).
            final: Final time boundary condition (float or tuple).
                Tuple format: ("free", value), ("minimize", value), or ("maximize", value).
            min: Minimum bound for time variable (required).
            max: Maximum bound for time variable (required).

        Raises:
            ValueError: If a tuple boundary condition does not have exactly two
                elements, has an unknown type, or has a non-real-number value.
        """
        # Local import keeps this module's top-level dependencies unchanged.
        import numbers

        # Only tuple-form boundary conditions are validated here; scalars are
        # stored as given.
        for name, value in (("initial", initial), ("final", final)):
            if not isinstance(value, tuple):
                continue
            if len(value) != 2:
                raise ValueError(f"{name} tuple must have exactly 2 elements: (type, value)")
            bc_type, bc_value = value
            if bc_type not in self._BC_TYPES:
                raise ValueError(
                    f"{name} boundary condition type must be 'free', "
                    f"'minimize', or 'maximize', got '{bc_type}'"
                )
            # numbers.Real accepts everything (int, float) did, plus other real
            # scalars such as fractions.Fraction and NumPy integer/floating
            # scalars (e.g. np.float32, np.int64 are not int/float subclasses).
            if not isinstance(bc_value, numbers.Real):
                raise ValueError(
                    f"{name} boundary condition value must be a number, "
                    f"got {type(bc_value).__name__}"
                )

        self.initial = initial
        self.final = final
        self.min = min
        self.max = max
        # Time derivative is always 1.0 internally
        self.derivative = 1.0
        self._scaling_min = None
        self._scaling_max = None

    @property
    def scaling_min(self):
        """Get the scaling minimum bound for the time variable.

        Returns:
            Scaling minimum value, or None if not set.
        """
        return self._scaling_min

    @scaling_min.setter
    def scaling_min(self, val):
        """Set the scaling minimum bound for the time variable.

        Args:
            val: Scaling minimum value (float or None); stored via ``float(val)``.
        """
        self._scaling_min = float(val) if val is not None else None

    @property
    def scaling_max(self):
        """Get the scaling maximum bound for the time variable.

        Returns:
            Scaling maximum value, or None if not set.
        """
        return self._scaling_max

    @scaling_max.setter
    def scaling_max(self, val):
        """Set the scaling maximum bound for the time variable.

        Args:
            val: Scaling maximum value (float or None); stored via ``float(val)``.
        """
        self._scaling_max = float(val) if val is not None else None
|
|
@@ -0,0 +1,420 @@
|
|
|
1
|
+
"""Unification functions for aggregating symbolic State and Control objects.
|
|
2
|
+
|
|
3
|
+
This module provides the unification layer that transforms multiple symbolic State
|
|
4
|
+
and Control objects into unified representations for numerical optimization.
|
|
5
|
+
|
|
6
|
+
The unification process:
|
|
7
|
+
1. **Collection**: Gathers all State and Control objects from expression trees
|
|
8
|
+
2. **Sorting**: Organizes variables (user-defined first, then augmented)
|
|
9
|
+
3. **Aggregation**: Concatenates bounds, guesses, and boundary conditions
|
|
10
|
+
4. **Slice Assignment**: Assigns each State/Control a slice for indexing
|
|
11
|
+
5. **Unified Representation**: Creates UnifiedState/UnifiedControl objects
|
|
12
|
+
|
|
13
|
+
This separation allows users to define problems with natural variable names
|
|
14
|
+
while maintaining efficient vectorized operations during optimization.
|
|
15
|
+
|
|
16
|
+
Example:
|
|
17
|
+
Creating and unifying multiple states::
|
|
18
|
+
|
|
19
|
+
import openscvx as ox
|
|
20
|
+
from openscvx.symbolic.unified import unify_states
|
|
21
|
+
|
|
22
|
+
# Define separate symbolic states
|
|
23
|
+
position = ox.State("position", shape=(3,), min=-10, max=10)
|
|
24
|
+
velocity = ox.State("velocity", shape=(3,), min=-5, max=5)
|
|
25
|
+
mass = ox.State("mass", shape=(1,), min=0.1, max=10.0)
|
|
26
|
+
|
|
27
|
+
# Unify into single state vector
|
|
28
|
+
unified_x = unify_states([position, velocity, mass], name="x")
|
|
29
|
+
|
|
30
|
+
# Access unified properties
|
|
31
|
+
print(unified_x.shape) # (7,) - combined shape
|
|
32
|
+
print(unified_x.min) # Combined bounds: [-10, -10, -10, -5, -5, -5, 0.1]
|
|
33
|
+
print(unified_x.true) # Access only user-defined states
|
|
34
|
+
|
|
35
|
+
Accessing slices after unification::
|
|
36
|
+
|
|
37
|
+
# After unification, each State has a slice assigned
|
|
38
|
+
print(position._slice) # slice(0, 3)
|
|
39
|
+
print(velocity._slice) # slice(3, 6)
|
|
40
|
+
print(mass._slice) # slice(6, 7)
|
|
41
|
+
|
|
42
|
+
# During lowering, these slices extract values from unified vector
|
|
43
|
+
x_unified = jnp.array([1, 2, 3, 4, 5, 6, 7])
|
|
44
|
+
position_val = x_unified[position._slice] # [1, 2, 3]
|
|
45
|
+
|
|
46
|
+
See Also:
|
|
47
|
+
- UnifiedState: Dataclass for unified state representation (in openscvx.lowered.unified)
|
|
48
|
+
- UnifiedControl: Dataclass for unified control representation (in openscvx.lowered.unified)
|
|
49
|
+
- State: Individual symbolic state variable (symbolic/expr/state.py)
|
|
50
|
+
- Control: Individual symbolic control variable (symbolic/expr/control.py)
|
|
51
|
+
"""
|
|
52
|
+
|
|
53
|
+
from typing import List
|
|
54
|
+
|
|
55
|
+
import numpy as np
|
|
56
|
+
|
|
57
|
+
from openscvx.lowered.unified import UnifiedControl, UnifiedState
|
|
58
|
+
from openscvx.symbolic.expr.control import Control
|
|
59
|
+
from openscvx.symbolic.expr.state import State
|
|
60
|
+
|
|
61
|
+
# Public names of this module. UnifiedState and UnifiedControl are imported
# from openscvx.lowered.unified and re-exported here for backwards compatibility.
__all__ = [
    "unify_states",
    "unify_controls",
    "UnifiedState",
    "UnifiedControl",
]
|
|
63
|
+
|
|
64
|
+
|
|
65
|
+
def unify_states(states: List[State], name: str = "unified_state") -> UnifiedState:
    """Create a UnifiedState from a list of State objects.

    This function aggregates multiple symbolic State objects into a single
    unified state vector description for numerical optimization. It:

    1. Sorts states: user-defined states first, then augmented states (names
       starting with ``'_'``). Relative order within each group is preserved.
    2. Concatenates per-state properties (bounds, guesses, boundary
       conditions, scaling bounds) in that sorted order.
    3. Locates the ``"time"`` state and the ``"_ctcs_aug_*"`` states and
       derives ``time_slice`` / ``ctcs_slice`` from their existing ``_slice``.
    4. Returns a UnifiedState with all aggregated data.

    Args:
        states (List[State]): List of State objects to unify. Can include both
            user-defined states and augmented states (names starting with '_').
        name (str): Name identifier for the unified state vector (default: "unified_state")

    Returns:
        UnifiedState: Unified state object containing:
            - ``shape`` equal to ``(sum of state.shape[0],)``
            - ``min`` / ``max``, with missing bounds filled by ``-inf`` / ``+inf``
            - ``initial_type`` / ``final_type``, with missing types filled by ``"Free"``
            - ``guess`` (concatenated along axis 1), ``initial``, ``final``,
              ``_initial``, ``_final`` concatenated from the states that
              define them, or ``None`` when no state does
            - ``scaling_min`` / ``scaling_max`` when at least one state defines
              a scaling bound (falling back to min/max, then +/-inf); otherwise None
            - ``_true_dim``, ``_true_slice``, ``_augmented_slice``,
              ``time_slice`` and ``ctcs_slice``

        An empty ``states`` list yields ``UnifiedState(name=name, shape=(0,))``.

    Example:
        Basic unification::

            import openscvx as ox
            from openscvx.symbolic.unified import unify_states

            position = ox.State("pos", shape=(3,), min=-10, max=10)
            velocity = ox.State("vel", shape=(3,), min=-5, max=5)

            unified = unify_states([position, velocity], name="x")
            print(unified.shape)  # (6,)
            print(unified._true_dim)  # 6 (all are user states)

        With augmented states::

            # CTCS or other features may add augmented states
            time_state = ox.State("time", shape=(1,))
            ctcs_aug = ox.State("_ctcs_aug_0", shape=(2,))  # Augmented state

            unified = unify_states([position, velocity, time_state, ctcs_aug])
            print(unified._true_dim)  # 7 (pos + vel + time)
            print(unified.true.shape)  # (7,)
            print(unified.augmented.shape)  # (2,) - only CTCS augmented

    Note:
        This function does not assign ``_slice`` on the individual State
        objects; it only reads ``_slice`` from the ``"time"`` state and the
        CTCS augmented states. Those slices must already be set, and are
        assumed to agree with the sorted (true-then-augmented) layout used
        here. ``ctcs_slice`` spans from the first to the last CTCS state, so
        the CTCS states are assumed to be contiguous.

        ``guess``/``initial``/``final``/``_initial``/``_final`` skip states
        whose value is ``None``. If only some states define a value, the
        concatenated result is shorter than ``shape`` and no longer aligned
        with the state slices. NOTE(review): confirm callers guarantee
        all-or-none for these fields.

    See Also:
        - UnifiedState: Return type with detailed documentation
        - unify_controls(): Analogous function for Control objects
        - State: Individual symbolic state variable
    """
    if not states:
        return UnifiedState(name=name, shape=(0,))

    # True (user-defined) states first, then augmented states (name starts with '_').
    true_states = [state for state in states if not state.name.startswith("_")]
    augmented_states = [state for state in states if state.name.startswith("_")]
    sorted_states = true_states + augmented_states

    total_shape = sum(state.shape[0] for state in sorted_states)

    # Always full-length (one entry per state, with defaults for missing values).
    min_arrays = []
    max_arrays = []
    initial_type_arrays = []
    final_type_arrays = []
    # Only collected from states that define them (see Note in the docstring).
    guess_arrays = []
    initial_arrays = []
    final_arrays = []
    _initial_arrays = []
    _final_arrays = []

    for state in sorted_states:
        # Missing bounds mean "unbounded" in that direction.
        if state.min is not None:
            min_arrays.append(state.min)
        else:
            min_arrays.append(np.full(state.shape[0], -np.inf))

        if state.max is not None:
            max_arrays.append(state.max)
        else:
            max_arrays.append(np.full(state.shape[0], np.inf))

        if state.guess is not None:
            guess_arrays.append(state.guess)
        if state.initial is not None:
            initial_arrays.append(state.initial)
        if state.final is not None:
            final_arrays.append(state.final)
        if state._initial is not None:
            _initial_arrays.append(state._initial)
        if state._final is not None:
            _final_arrays.append(state._final)

        # Missing boundary-condition types default to "Free" for every component.
        if state.initial_type is not None:
            initial_type_arrays.append(state.initial_type)
        else:
            initial_type_arrays.append(np.full(state.shape[0], "Free", dtype=object))

        if state.final_type is not None:
            final_type_arrays.append(state.final_type)
        else:
            final_type_arrays.append(np.full(state.shape[0], "Free", dtype=object))

    unified_min = np.concatenate(min_arrays) if min_arrays else None
    unified_max = np.concatenate(max_arrays) if max_arrays else None
    unified_guess = np.concatenate(guess_arrays, axis=1) if guess_arrays else None
    unified_initial = np.concatenate(initial_arrays) if initial_arrays else None
    unified_final = np.concatenate(final_arrays) if final_arrays else None
    unified__initial = np.concatenate(_initial_arrays) if _initial_arrays else None
    unified__final = np.concatenate(_final_arrays) if _final_arrays else None
    unified_initial_type = np.concatenate(initial_type_arrays) if initial_type_arrays else None
    unified_final_type = np.concatenate(final_type_arrays) if final_type_arrays else None

    # Dimension of the user-defined states only (they occupy [0, true_dim)).
    true_dim = sum(state.shape[0] for state in true_states)

    # next() yields a State or None; test for None explicitly rather than
    # relying on the truthiness of a symbolic State object.
    time_state = next((s for s in sorted_states if s.name == "time"), None)
    time_slice = time_state._slice if time_state is not None else None

    # CTCS states are assumed contiguous: span first start .. last stop.
    ctcs_states = [s for s in sorted_states if s.name.startswith("_ctcs_aug_")]
    ctcs_slice = (
        slice(ctcs_states[0]._slice.start, ctcs_states[-1]._slice.stop) if ctcs_states else None
    )

    # Scaling bounds are only built when at least one state defines one; each
    # state then contributes scaling -> min/max -> +/-inf, in that priority.
    unified_scaling_min = None
    unified_scaling_max = None
    has_any_scaling = any(
        state.scaling_min is not None or state.scaling_max is not None for state in sorted_states
    )
    if has_any_scaling:
        scaling_min_list = []
        scaling_max_list = []
        for state in sorted_states:
            if state.scaling_min is not None:
                scaling_min_list.append(state.scaling_min)
            elif state.min is not None:
                scaling_min_list.append(state.min)
            else:
                scaling_min_list.append(np.full(state.shape[0], -np.inf))

            if state.scaling_max is not None:
                scaling_max_list.append(state.scaling_max)
            elif state.max is not None:
                scaling_max_list.append(state.max)
            else:
                scaling_max_list.append(np.full(state.shape[0], np.inf))

        unified_scaling_min = np.concatenate(scaling_min_list)
        unified_scaling_max = np.concatenate(scaling_max_list)

    return UnifiedState(
        name=name,
        shape=(total_shape,),
        min=unified_min,
        max=unified_max,
        guess=unified_guess,
        initial=unified_initial,
        final=unified_final,
        _initial=unified__initial,
        _final=unified__final,
        initial_type=unified_initial_type,
        final_type=unified_final_type,
        _true_dim=true_dim,
        _true_slice=slice(0, true_dim),
        _augmented_slice=slice(true_dim, total_shape),
        time_slice=time_slice,
        ctcs_slice=ctcs_slice,
        scaling_min=unified_scaling_min,
        scaling_max=unified_scaling_max,
    )
|
|
264
|
+
|
|
265
|
+
|
|
266
|
+
def unify_controls(controls: List[Control], name: str = "unified_control") -> UnifiedControl:
    """Create a UnifiedControl from a list of Control objects.

    This function aggregates multiple symbolic Control objects into a single
    unified control vector description for numerical optimization. It:

    1. Sorts controls: user-defined controls first, then augmented controls
       (names starting with ``'_'``). Relative order within each group is
       preserved.
    2. Concatenates per-control properties (bounds, guesses, scaling bounds)
       in that sorted order.
    3. Locates the ``"_time_dilation"`` control and derives
       ``time_dilation_slice`` from its existing ``_slice``.
    4. Returns a UnifiedControl with all aggregated data.

    Args:
        controls (List[Control]): List of Control objects to unify. Can include both
            user-defined controls and augmented controls (names starting with '_').
        name (str): Name identifier for the unified control vector (default: "unified_control")

    Returns:
        UnifiedControl: Unified control object containing:
            - ``shape`` equal to ``(sum of control.shape[0],)``
            - ``min`` / ``max``, with missing bounds filled by ``-inf`` / ``+inf``
            - ``guess`` concatenated along axis 1 from the controls that
              define one, or ``None`` when no control does
            - ``scaling_min`` / ``scaling_max`` when at least one control
              defines a scaling bound (falling back to min/max, then +/-inf);
              otherwise None
            - ``_true_dim``, ``_true_slice``, ``_augmented_slice`` and
              ``time_dilation_slice``

        An empty ``controls`` list yields ``UnifiedControl(name=name, shape=(0,))``.

    Example:
        Basic unification::

            import openscvx as ox
            from openscvx.symbolic.unified import unify_controls

            thrust = ox.Control("thrust", shape=(3,), min=0, max=10)
            torque = ox.Control("torque", shape=(3,), min=-1, max=1)

            unified = unify_controls([thrust, torque], name="u")
            print(unified.shape)  # (6,)
            print(unified._true_dim)  # 6 (all are user controls)

        With augmented controls::

            # Time-optimal problems may add time dilation control
            time_dilation = ox.Control("_time_dilation", shape=(1,))

            unified = unify_controls([thrust, torque, time_dilation])
            print(unified._true_dim)  # 6 (thrust + torque)
            print(unified.true.shape)  # (6,)
            print(unified.augmented.shape)  # (1,) - time dilation

    Note:
        This function does not assign ``_slice`` on the individual Control
        objects; it only reads ``_slice`` from the ``"_time_dilation"``
        control. That slice must already be set and is assumed to agree with
        the sorted (true-then-augmented) layout used here.

        ``guess`` skips controls whose guess is ``None``; if only some
        controls define one, the concatenated guess has fewer columns than
        ``shape``. NOTE(review): confirm callers guarantee all-or-none.

    See Also:
        - UnifiedControl: Return type with detailed documentation
        - unify_states(): Analogous function for State objects
        - Control: Individual symbolic control variable
    """
    if not controls:
        return UnifiedControl(name=name, shape=(0,))

    # True (user-defined) controls first, then augmented controls (name starts with '_').
    true_controls = [control for control in controls if not control.name.startswith("_")]
    augmented_controls = [control for control in controls if control.name.startswith("_")]
    sorted_controls = true_controls + augmented_controls

    total_shape = sum(control.shape[0] for control in sorted_controls)

    min_arrays = []
    max_arrays = []
    guess_arrays = []

    for control in sorted_controls:
        # Missing bounds mean "unbounded" in that direction.
        if control.min is not None:
            min_arrays.append(control.min)
        else:
            min_arrays.append(np.full(control.shape[0], -np.inf))

        if control.max is not None:
            max_arrays.append(control.max)
        else:
            max_arrays.append(np.full(control.shape[0], np.inf))

        # Guesses are only collected from controls that define one.
        if control.guess is not None:
            guess_arrays.append(control.guess)

    unified_min = np.concatenate(min_arrays) if min_arrays else None
    unified_max = np.concatenate(max_arrays) if max_arrays else None
    unified_guess = np.concatenate(guess_arrays, axis=1) if guess_arrays else None

    # Dimension of the user-defined controls only (they occupy [0, true_dim)).
    true_dim = sum(control.shape[0] for control in true_controls)

    # next() yields a Control or None; test for None explicitly rather than
    # relying on the truthiness of a symbolic Control object.
    time_dilation_control = next((c for c in sorted_controls if c.name == "_time_dilation"), None)
    time_dilation_slice = (
        time_dilation_control._slice if time_dilation_control is not None else None
    )

    # Scaling bounds are only built when at least one control defines one;
    # each control then contributes scaling -> min/max -> +/-inf, in that priority.
    unified_scaling_min = None
    unified_scaling_max = None
    has_any_scaling = any(
        control.scaling_min is not None or control.scaling_max is not None
        for control in sorted_controls
    )
    if has_any_scaling:
        scaling_min_list = []
        scaling_max_list = []
        for control in sorted_controls:
            if control.scaling_min is not None:
                scaling_min_list.append(control.scaling_min)
            elif control.min is not None:
                scaling_min_list.append(control.min)
            else:
                scaling_min_list.append(np.full(control.shape[0], -np.inf))

            if control.scaling_max is not None:
                scaling_max_list.append(control.scaling_max)
            elif control.max is not None:
                scaling_max_list.append(control.max)
            else:
                scaling_max_list.append(np.full(control.shape[0], np.inf))

        unified_scaling_min = np.concatenate(scaling_min_list)
        unified_scaling_max = np.concatenate(scaling_max_list)

    return UnifiedControl(
        name=name,
        shape=(total_shape,),
        min=unified_min,
        max=unified_max,
        guess=unified_guess,
        _true_dim=true_dim,
        _true_slice=slice(0, true_dim),
        _augmented_slice=slice(true_dim, total_shape),
        time_dilation_slice=time_dilation_slice,
        scaling_min=unified_scaling_min,
        scaling_max=unified_scaling_max,
    )
|
|
@@ -0,0 +1,20 @@
|
|
|
1
|
+
"""Utility functions for caching, printing, and output formatting.
|
|
2
|
+
|
|
3
|
+
This module provides utilities for OpenSCvx.
|
|
4
|
+
"""
|
|
5
|
+
|
|
6
|
+
from .utils import (
|
|
7
|
+
calculate_cost_from_boundaries,
|
|
8
|
+
gen_vertices,
|
|
9
|
+
generate_orthogonal_unit_vectors,
|
|
10
|
+
get_kp_pose,
|
|
11
|
+
rot,
|
|
12
|
+
)
|
|
13
|
+
|
|
14
|
+
# Names exported by a star import of this package; every entry is one of the
# helpers imported from the ``.utils`` submodule above.
__all__ = [
    "generate_orthogonal_unit_vectors", "rot", "gen_vertices",
    "get_kp_pose", "calculate_cost_from_boundaries",
]
|