openscvx 0.3.2.dev170__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of openscvx might be problematic. Click here for more details.

Files changed (79) hide show
  1. openscvx/__init__.py +123 -0
  2. openscvx/_version.py +34 -0
  3. openscvx/algorithms/__init__.py +92 -0
  4. openscvx/algorithms/autotuning.py +24 -0
  5. openscvx/algorithms/base.py +351 -0
  6. openscvx/algorithms/optimization_results.py +215 -0
  7. openscvx/algorithms/penalized_trust_region.py +384 -0
  8. openscvx/config.py +437 -0
  9. openscvx/discretization/__init__.py +47 -0
  10. openscvx/discretization/discretization.py +236 -0
  11. openscvx/expert/__init__.py +23 -0
  12. openscvx/expert/byof.py +326 -0
  13. openscvx/expert/lowering.py +419 -0
  14. openscvx/expert/validation.py +357 -0
  15. openscvx/integrators/__init__.py +48 -0
  16. openscvx/integrators/runge_kutta.py +281 -0
  17. openscvx/lowered/__init__.py +30 -0
  18. openscvx/lowered/cvxpy_constraints.py +23 -0
  19. openscvx/lowered/cvxpy_variables.py +124 -0
  20. openscvx/lowered/dynamics.py +34 -0
  21. openscvx/lowered/jax_constraints.py +133 -0
  22. openscvx/lowered/parameters.py +54 -0
  23. openscvx/lowered/problem.py +70 -0
  24. openscvx/lowered/unified.py +718 -0
  25. openscvx/plotting/__init__.py +63 -0
  26. openscvx/plotting/plotting.py +756 -0
  27. openscvx/plotting/scp_iteration.py +299 -0
  28. openscvx/plotting/viser/__init__.py +126 -0
  29. openscvx/plotting/viser/animated.py +605 -0
  30. openscvx/plotting/viser/plotly_integration.py +333 -0
  31. openscvx/plotting/viser/primitives.py +355 -0
  32. openscvx/plotting/viser/scp.py +459 -0
  33. openscvx/plotting/viser/server.py +112 -0
  34. openscvx/problem.py +734 -0
  35. openscvx/propagation/__init__.py +60 -0
  36. openscvx/propagation/post_processing.py +104 -0
  37. openscvx/propagation/propagation.py +248 -0
  38. openscvx/solvers/__init__.py +51 -0
  39. openscvx/solvers/cvxpy.py +226 -0
  40. openscvx/symbolic/__init__.py +9 -0
  41. openscvx/symbolic/augmentation.py +630 -0
  42. openscvx/symbolic/builder.py +492 -0
  43. openscvx/symbolic/constraint_set.py +92 -0
  44. openscvx/symbolic/expr/__init__.py +222 -0
  45. openscvx/symbolic/expr/arithmetic.py +517 -0
  46. openscvx/symbolic/expr/array.py +632 -0
  47. openscvx/symbolic/expr/constraint.py +796 -0
  48. openscvx/symbolic/expr/control.py +135 -0
  49. openscvx/symbolic/expr/expr.py +720 -0
  50. openscvx/symbolic/expr/lie/__init__.py +87 -0
  51. openscvx/symbolic/expr/lie/adjoint.py +357 -0
  52. openscvx/symbolic/expr/lie/se3.py +172 -0
  53. openscvx/symbolic/expr/lie/so3.py +138 -0
  54. openscvx/symbolic/expr/linalg.py +279 -0
  55. openscvx/symbolic/expr/math.py +699 -0
  56. openscvx/symbolic/expr/spatial.py +209 -0
  57. openscvx/symbolic/expr/state.py +607 -0
  58. openscvx/symbolic/expr/stl.py +136 -0
  59. openscvx/symbolic/expr/variable.py +321 -0
  60. openscvx/symbolic/hashing.py +112 -0
  61. openscvx/symbolic/lower.py +760 -0
  62. openscvx/symbolic/lowerers/__init__.py +106 -0
  63. openscvx/symbolic/lowerers/cvxpy.py +1302 -0
  64. openscvx/symbolic/lowerers/jax.py +1382 -0
  65. openscvx/symbolic/preprocessing.py +757 -0
  66. openscvx/symbolic/problem.py +110 -0
  67. openscvx/symbolic/time.py +116 -0
  68. openscvx/symbolic/unified.py +420 -0
  69. openscvx/utils/__init__.py +20 -0
  70. openscvx/utils/cache.py +131 -0
  71. openscvx/utils/caching.py +210 -0
  72. openscvx/utils/printing.py +301 -0
  73. openscvx/utils/profiling.py +37 -0
  74. openscvx/utils/utils.py +100 -0
  75. openscvx-0.3.2.dev170.dist-info/METADATA +350 -0
  76. openscvx-0.3.2.dev170.dist-info/RECORD +79 -0
  77. openscvx-0.3.2.dev170.dist-info/WHEEL +5 -0
  78. openscvx-0.3.2.dev170.dist-info/licenses/LICENSE +201 -0
  79. openscvx-0.3.2.dev170.dist-info/top_level.txt +1 -0
@@ -0,0 +1,110 @@
1
+ """SymbolicProblem dataclass - container for symbolic problem specification.
2
+
3
+ This module provides the SymbolicProblem dataclass that represents a trajectory
4
+ optimization problem in symbolic form, before lowering to executable code.
5
+
6
+ The SymbolicProblem can represent two lifecycle stages:
7
+
8
+ 1. **Before preprocessing**: Raw user input with unsorted constraints
9
+ 2. **After preprocessing**: Augmented and validated, ready for lowering
10
+
11
+ Use `is_preprocessed` to check which stage the problem is in.
12
+ """
13
+
14
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple

from openscvx.symbolic.constraint_set import ConstraintSet

if TYPE_CHECKING:
    from openscvx.symbolic.expr import Expr
    from openscvx.symbolic.expr.control import Control
    from openscvx.symbolic.expr.state import State
23
+
24
+
25
@dataclass
class SymbolicProblem:
    """Container for a trajectory optimization problem in symbolic form.

    The same dataclass is used at two points of the pipeline:

    1. **Before preprocessing**: holds the raw dynamics, states, controls and
       unsorted constraints supplied by the user. The ``*_prop`` fields are
       still ``None``.
    2. **After preprocessing**: dynamics and states have been augmented
       (CTCS, time dilation), constraints have been categorized and the
       propagation fields are populated.

    Use ``is_preprocessed`` to find out which stage an instance is in.

    Attributes:
        dynamics: Symbolic dynamics expression (dx/dt = f(x, u)).
            After preprocessing, includes CTCS augmented state dynamics.
        states: List of State objects. After preprocessing, includes the
            time state and CTCS augmented states.
        controls: List of Control objects. After preprocessing, includes
            the time dilation control.
        constraints: ConstraintSet holding all constraints. Before
            preprocessing, raw constraints live in ``constraints.unsorted``.
            After preprocessing, constraints are categorized into ctcs,
            nodal, nodal_convex, etc.
        parameters: Dictionary mapping parameter names to their values
            (described upstream as numpy arrays; the value type is not
            enforced by this class).
        N: Number of discretization nodes.
        node_intervals: List of (start, end) tuples for CTCS constraint
            intervals. Populated during preprocessing when CTCS constraints
            are sorted. Defaults to a fresh empty list per instance.
        dynamics_prop: Propagation dynamics (may include extra states).
            None before preprocessing, populated after.
        states_prop: Propagation states (may include extra states).
            None before preprocessing, populated after.
        controls_prop: Propagation controls (typically same as controls).
            None before preprocessing, populated after.

    Example:
        Before preprocessing::

            problem = SymbolicProblem(
                dynamics=dynamics_expr,
                states=[x, v],
                controls=[u],
                constraints=ConstraintSet(unsorted=[c1, c2, c3]),
                parameters={"mass": 1.0},
                N=50,
            )
            assert not problem.is_preprocessed

        After preprocessing::

            processed = preprocess_symbolic_problem(problem, time=time_config)
            assert processed.is_preprocessed
            assert processed.constraints.is_categorized
            # Now ready for lowering
            lowered = lower_symbolic_problem(processed)
    """

    # Core problem specification
    dynamics: "Expr"
    states: List["State"]
    controls: List["Control"]
    constraints: ConstraintSet
    # ``Any`` (typing), not the builtin ``any`` function, is the intended
    # "unconstrained value" annotation.
    parameters: Dict[str, Any]
    N: int

    # CTCS node intervals (populated during preprocessing)
    node_intervals: List[Tuple[int, int]] = field(default_factory=list)

    # Propagation (None before preprocessing, populated after)
    dynamics_prop: Optional["Expr"] = None
    states_prop: Optional[List["State"]] = None
    controls_prop: Optional[List["Control"]] = None

    @property
    def is_preprocessed(self) -> bool:
        """Return True if the problem has been preprocessed.

        Both conditions must hold:

        1. ``constraints.is_categorized`` is truthy.
        2. ``dynamics_prop`` has been set (is not ``None``).
        """
        return self.constraints.is_categorized and self.dynamics_prop is not None
@@ -0,0 +1,116 @@
1
+ from typing import Union
2
+
3
+
4
class Time:
    """Time configuration for trajectory optimization problems.

    This class encapsulates time-related parameters for trajectory optimization.
    The time derivative is internally assumed to be 1.0.

    Attributes:
        initial (float or tuple): Initial time boundary condition.
            Can be a float (fixed) or tuple like ("free", value), ("minimize", value),
            or ("maximize", value).
        final (float or tuple): Final time boundary condition.
            Can be a float (fixed) or tuple like ("free", value), ("minimize", value),
            or ("maximize", value).
        min (float): Minimum bound for time variable (required).
        max (float): Maximum bound for time variable (required).
        derivative (float): Always 1.0.

    Example:
        ```python
        # Fixed initial and final time
        time = Time(initial=0.0, final=10.0, min=0.0, max=20.0)

        # Free final time
        time = Time(initial=0.0, final=("free", 10.0), min=0.0, max=20.0)

        # Minimize final time
        time = Time(initial=0.0, final=("minimize", 10.0), min=0.0, max=20.0)

        # Maximize initial time
        time = Time(initial=("maximize", 0.0), final=10.0, min=0.0, max=20.0)
        ```
    """

    # Tags accepted as the first element of a boundary-condition tuple.
    _BC_TYPES = ("free", "minimize", "maximize")

    def __init__(
        self,
        initial: Union[float, tuple],
        final: Union[float, tuple],
        min: float,
        max: float,
    ):
        """Initialize a Time object.

        Args:
            initial: Initial time boundary condition (float or tuple).
                Tuple format: ("free", value), ("minimize", value), or ("maximize", value).
            final: Final time boundary condition (float or tuple).
                Tuple format: ("free", value), ("minimize", value), or ("maximize", value).
            min: Minimum bound for time variable (required).
            max: Maximum bound for time variable (required).

        Raises:
            ValueError: If a tuple does not have exactly two elements, its
                first element is not one of the accepted tags, or its second
                element is not a real number.
        """
        # Imported locally so the block does not depend on a module-level import.
        import numbers

        # Only tuple-valued boundary conditions are validated; any other value
        # is stored as given.
        for name, value in (("initial", initial), ("final", final)):
            if not isinstance(value, tuple):
                continue
            if len(value) != 2:
                raise ValueError(f"{name} tuple must have exactly 2 elements: (type, value)")
            bc_type, bc_value = value
            if bc_type not in self._BC_TYPES:
                raise ValueError(
                    f"{name} boundary condition type must be 'free', "
                    f"'minimize', or 'maximize', got '{bc_type}'"
                )
            # numbers.Real covers int/float as before and additionally accepts
            # other real scalars (e.g. numpy integer/floating scalars) that are
            # numbers but not int/float subclasses.
            if not isinstance(bc_value, numbers.Real):
                raise ValueError(
                    f"{name} boundary condition value must be a number, "
                    f"got {type(bc_value).__name__}"
                )

        self.initial = initial
        self.final = final
        self.min = min
        self.max = max
        # Time derivative is always 1.0 internally
        self.derivative = 1.0
        self._scaling_min = None
        self._scaling_max = None

    @property
    def scaling_min(self):
        """Get the scaling minimum bound for the time variable.

        Returns:
            Scaling minimum value, or None if not set.
        """
        return self._scaling_min

    @scaling_min.setter
    def scaling_min(self, val):
        """Set the scaling minimum bound for the time variable.

        Args:
            val: Scaling minimum value (converted with ``float``) or None.
        """
        self._scaling_min = None if val is None else float(val)

    @property
    def scaling_max(self):
        """Get the scaling maximum bound for the time variable.

        Returns:
            Scaling maximum value, or None if not set.
        """
        return self._scaling_max

    @scaling_max.setter
    def scaling_max(self, val):
        """Set the scaling maximum bound for the time variable.

        Args:
            val: Scaling maximum value (converted with ``float``) or None.
        """
        self._scaling_max = None if val is None else float(val)
@@ -0,0 +1,420 @@
1
+ """Unification functions for aggregating symbolic State and Control objects.
2
+
3
+ This module provides the unification layer that transforms multiple symbolic State
4
+ and Control objects into unified representations for numerical optimization.
5
+
6
+ The unification process:
7
+ 1. **Collection**: Gathers all State and Control objects from expression trees
8
+ 2. **Sorting**: Organizes variables (user-defined first, then augmented)
9
+ 3. **Aggregation**: Concatenates bounds, guesses, and boundary conditions
10
+ 4. **Slice Assignment**: Assigns each State/Control a slice for indexing
11
+ 5. **Unified Representation**: Creates UnifiedState/UnifiedControl objects
12
+
13
+ This separation allows users to define problems with natural variable names
14
+ while maintaining efficient vectorized operations during optimization.
15
+
16
+ Example:
17
+ Creating and unifying multiple states::
18
+
19
+ import openscvx as ox
20
+ from openscvx.symbolic.unified import unify_states
21
+
22
+ # Define separate symbolic states
23
+ position = ox.State("position", shape=(3,), min=-10, max=10)
24
+ velocity = ox.State("velocity", shape=(3,), min=-5, max=5)
25
+ mass = ox.State("mass", shape=(1,), min=0.1, max=10.0)
26
+
27
+ # Unify into single state vector
28
+ unified_x = unify_states([position, velocity, mass], name="x")
29
+
30
+ # Access unified properties
31
+ print(unified_x.shape) # (7,) - combined shape
32
+ print(unified_x.min) # Combined bounds: [-10, -10, -10, -5, -5, -5, 0.1]
33
+ print(unified_x.true) # Access only user-defined states
34
+
35
+ Accessing slices after unification::
36
+
37
+ # After unification, each State has a slice assigned
38
+ print(position._slice) # slice(0, 3)
39
+ print(velocity._slice) # slice(3, 6)
40
+ print(mass._slice) # slice(6, 7)
41
+
42
+ # During lowering, these slices extract values from unified vector
43
+ x_unified = jnp.array([1, 2, 3, 4, 5, 6, 7])
44
+ position_val = x_unified[position._slice] # [1, 2, 3]
45
+
46
+ See Also:
47
+ - UnifiedState: Dataclass for unified state representation (in openscvx.lowered.unified)
48
+ - UnifiedControl: Dataclass for unified control representation (in openscvx.lowered.unified)
49
+ - State: Individual symbolic state variable (symbolic/expr/state.py)
50
+ - Control: Individual symbolic control variable (symbolic/expr/control.py)
51
+ """
52
+
53
+ from typing import List
54
+
55
+ import numpy as np
56
+
57
+ from openscvx.lowered.unified import UnifiedControl, UnifiedState
58
+ from openscvx.symbolic.expr.control import Control
59
+ from openscvx.symbolic.expr.state import State
60
+
61
+ # Re-export for backwards compatibility
62
+ __all__ = ["unify_states", "unify_controls", "UnifiedState", "UnifiedControl"]
63
+
64
+
65
def unify_states(states: List[State], name: str = "unified_state") -> UnifiedState:
    """Aggregate a list of State objects into a single UnifiedState.

    The function:

    1. Orders the states so that those whose name does not start with ``'_'``
       ("true" states) come first, followed by those whose name does
       ("augmented" states). Relative order inside each group is preserved.
    2. Concatenates per-state bounds, guesses, boundary values and boundary
       types in that order.
    3. Locates the slice of the state named ``"time"`` and the slice spanning
       all states whose name starts with ``"_ctcs_aug_"``.
    4. Builds unified scaling bounds when at least one state defines scaling.

    Args:
        states (List[State]): State objects to unify.
        name (str): Name given to the unified state (default: "unified_state").

    Returns:
        UnifiedState: For an empty ``states`` list, ``UnifiedState(name=name,
        shape=(0,))``. Otherwise an object whose ``shape`` is the sum of the
        leading dimensions of all states, together with the aggregated arrays
        and slices described above.

    Note:
        - This function only *reads* ``state._slice`` (for the time state and
          the CTCS augmented states); it does not assign slices. They are
          presumably set by an earlier pipeline stage - verify against callers.
        - Missing ``min``/``max`` are replaced by ``-inf``/``+inf``; missing
          ``initial_type``/``final_type`` are replaced by ``"Free"`` entries.
        - ``guess``, ``initial``, ``final``, ``_initial`` and ``_final`` are
          simply skipped for states where they are ``None``, so the resulting
          arrays only cover states that define them.
        - Guesses are concatenated along axis 1 (assumes each guess is a 2-D
          array with one column per state component - TODO confirm).

    See Also:
        - unify_controls(): Analogous function for Control objects
    """
    if not states:
        return UnifiedState(name=name, shape=(0,))

    def _or_filled(value, length, fill, **full_kwargs):
        # Use the provided array, or a constant-filled placeholder of the
        # state's length when the attribute is unset.
        return np.full(length, fill, **full_kwargs) if value is None else value

    def _join(chunks, **concat_kwargs):
        # Concatenate collected pieces; None when nothing was collected.
        return np.concatenate(chunks, **concat_kwargs) if chunks else None

    # User-defined states first, then augmented ('_'-prefixed) states.
    user_defined, augmented = [], []
    for s in states:
        (augmented if s.name.startswith("_") else user_defined).append(s)
    ordered = user_defined + augmented

    total_dim = sum(s.shape[0] for s in ordered)

    lower, upper = [], []
    guesses, starts, ends = [], [], []
    raw_starts, raw_ends = [], []
    start_kinds, end_kinds = [], []

    for s in ordered:
        width = s.shape[0]
        lower.append(_or_filled(s.min, width, -np.inf))
        upper.append(_or_filled(s.max, width, np.inf))

        # Optional attributes: collected only when present.
        for value, bucket in (
            (s.guess, guesses),
            (s.initial, starts),
            (s.final, ends),
            (s._initial, raw_starts),
            (s._final, raw_ends),
        ):
            if value is not None:
                bucket.append(value)

        start_kinds.append(_or_filled(s.initial_type, width, "Free", dtype=object))
        end_kinds.append(_or_filled(s.final_type, width, "Free", dtype=object))

    # Dimension contributed by user-defined (non-augmented) states only.
    true_dim = sum(s.shape[0] for s in user_defined)

    # Slice of the first state called "time", if any.
    time_state = None
    for s in ordered:
        if s.name == "time":
            time_state = s
            break
    time_slice = time_state._slice if time_state else None

    # Slice running from the first to the last CTCS augmented state.
    ctcs = [s for s in ordered if s.name.startswith("_ctcs_aug_")]
    if ctcs:
        ctcs_slice = slice(ctcs[0]._slice.start, ctcs[-1]._slice.stop)
    else:
        ctcs_slice = None

    # Scaling bounds are only built when some state defines scaling; states
    # without scaling fall back to their min/max (or -inf/+inf).
    scale_lo = None
    scale_hi = None
    if any(s.scaling_min is not None or s.scaling_max is not None for s in ordered):
        lo_parts, hi_parts = [], []
        for s in ordered:
            width = s.shape[0]
            lo = s.scaling_min
            lo_parts.append(lo if lo is not None else _or_filled(s.min, width, -np.inf))
            hi = s.scaling_max
            hi_parts.append(hi if hi is not None else _or_filled(s.max, width, np.inf))
        scale_lo = np.concatenate(lo_parts)
        scale_hi = np.concatenate(hi_parts)

    return UnifiedState(
        name=name,
        shape=(total_dim,),
        min=_join(lower),
        max=_join(upper),
        guess=_join(guesses, axis=1),
        initial=_join(starts),
        final=_join(ends),
        _initial=_join(raw_starts),
        _final=_join(raw_ends),
        initial_type=_join(start_kinds),
        final_type=_join(end_kinds),
        _true_dim=true_dim,
        _true_slice=slice(0, true_dim),
        _augmented_slice=slice(true_dim, total_dim),
        time_slice=time_slice,
        ctcs_slice=ctcs_slice,
        scaling_min=scale_lo,
        scaling_max=scale_hi,
    )
264
+
265
+
266
def unify_controls(controls: List[Control], name: str = "unified_control") -> UnifiedControl:
    """Aggregate a list of Control objects into a single UnifiedControl.

    The function:

    1. Orders the controls so that those whose name does not start with
       ``'_'`` ("true" controls) come first, followed by those whose name
       does ("augmented" controls). Relative order inside each group is
       preserved.
    2. Concatenates per-control bounds and guesses in that order.
    3. Locates the slice of the control named ``"_time_dilation"``.
    4. Builds unified scaling bounds when at least one control defines scaling.

    Args:
        controls (List[Control]): Control objects to unify.
        name (str): Name given to the unified control (default: "unified_control").

    Returns:
        UnifiedControl: For an empty ``controls`` list,
        ``UnifiedControl(name=name, shape=(0,))``. Otherwise an object whose
        ``shape`` is the sum of the leading dimensions of all controls,
        together with the aggregated arrays and slices described above.

    Note:
        - This function only *reads* ``control._slice`` (for the time dilation
          control); it does not assign slices. They are presumably set by an
          earlier pipeline stage - verify against callers.
        - Missing ``min``/``max`` are replaced by ``-inf``/``+inf``.
        - Controls whose ``guess`` is ``None`` are skipped, so the unified
          guess only covers controls that define one.
        - Guesses are concatenated along axis 1 (assumes each guess is a 2-D
          array with one column per control component - TODO confirm).

    See Also:
        - unify_states(): Analogous function for State objects
    """
    if not controls:
        return UnifiedControl(name=name, shape=(0,))

    def _or_filled(value, length, fill):
        # Use the provided array, or a constant-filled placeholder of the
        # control's length when the attribute is unset.
        return np.full(length, fill) if value is None else value

    def _join(chunks, **concat_kwargs):
        # Concatenate collected pieces; None when nothing was collected.
        return np.concatenate(chunks, **concat_kwargs) if chunks else None

    # User-defined controls first, then augmented ('_'-prefixed) controls.
    user_defined, augmented = [], []
    for c in controls:
        (augmented if c.name.startswith("_") else user_defined).append(c)
    ordered = user_defined + augmented

    total_dim = sum(c.shape[0] for c in ordered)

    lower, upper, guesses = [], [], []
    for c in ordered:
        width = c.shape[0]
        lower.append(_or_filled(c.min, width, -np.inf))
        upper.append(_or_filled(c.max, width, np.inf))
        guess = c.guess
        if guess is not None:
            guesses.append(guess)

    # Dimension contributed by user-defined (non-augmented) controls only.
    true_dim = sum(c.shape[0] for c in user_defined)

    # Slice of the first control called "_time_dilation", if any.
    dilation = None
    for c in ordered:
        if c.name == "_time_dilation":
            dilation = c
            break
    time_dilation_slice = dilation._slice if dilation else None

    # Scaling bounds are only built when some control defines scaling;
    # controls without scaling fall back to their min/max (or -inf/+inf).
    scale_lo = None
    scale_hi = None
    if any(c.scaling_min is not None or c.scaling_max is not None for c in ordered):
        lo_parts, hi_parts = [], []
        for c in ordered:
            width = c.shape[0]
            lo = c.scaling_min
            lo_parts.append(lo if lo is not None else _or_filled(c.min, width, -np.inf))
            hi = c.scaling_max
            hi_parts.append(hi if hi is not None else _or_filled(c.max, width, np.inf))
        scale_lo = np.concatenate(lo_parts)
        scale_hi = np.concatenate(hi_parts)

    return UnifiedControl(
        name=name,
        shape=(total_dim,),
        min=_join(lower),
        max=_join(upper),
        guess=_join(guesses, axis=1),
        _true_dim=true_dim,
        _true_slice=slice(0, true_dim),
        _augmented_slice=slice(true_dim, total_dim),
        time_dilation_slice=time_dilation_slice,
        scaling_min=scale_lo,
        scaling_max=scale_hi,
    )
@@ -0,0 +1,20 @@
1
+ """Utility functions for caching, printing, and output formatting.
2
+
3
+ This module provides utilities for OpenSCvx.
4
+ """
5
+
6
+ from .utils import (
7
+ calculate_cost_from_boundaries,
8
+ gen_vertices,
9
+ generate_orthogonal_unit_vectors,
10
+ get_kp_pose,
11
+ rot,
12
+ )
13
+
14
+ __all__ = [
15
+ "generate_orthogonal_unit_vectors",
16
+ "rot",
17
+ "gen_vertices",
18
+ "get_kp_pose",
19
+ "calculate_cost_from_boundaries",
20
+ ]