openscvx 0.3.2.dev170__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of openscvx might be problematic. Click here for more details.
- openscvx/__init__.py +123 -0
- openscvx/_version.py +34 -0
- openscvx/algorithms/__init__.py +92 -0
- openscvx/algorithms/autotuning.py +24 -0
- openscvx/algorithms/base.py +351 -0
- openscvx/algorithms/optimization_results.py +215 -0
- openscvx/algorithms/penalized_trust_region.py +384 -0
- openscvx/config.py +437 -0
- openscvx/discretization/__init__.py +47 -0
- openscvx/discretization/discretization.py +236 -0
- openscvx/expert/__init__.py +23 -0
- openscvx/expert/byof.py +326 -0
- openscvx/expert/lowering.py +419 -0
- openscvx/expert/validation.py +357 -0
- openscvx/integrators/__init__.py +48 -0
- openscvx/integrators/runge_kutta.py +281 -0
- openscvx/lowered/__init__.py +30 -0
- openscvx/lowered/cvxpy_constraints.py +23 -0
- openscvx/lowered/cvxpy_variables.py +124 -0
- openscvx/lowered/dynamics.py +34 -0
- openscvx/lowered/jax_constraints.py +133 -0
- openscvx/lowered/parameters.py +54 -0
- openscvx/lowered/problem.py +70 -0
- openscvx/lowered/unified.py +718 -0
- openscvx/plotting/__init__.py +63 -0
- openscvx/plotting/plotting.py +756 -0
- openscvx/plotting/scp_iteration.py +299 -0
- openscvx/plotting/viser/__init__.py +126 -0
- openscvx/plotting/viser/animated.py +605 -0
- openscvx/plotting/viser/plotly_integration.py +333 -0
- openscvx/plotting/viser/primitives.py +355 -0
- openscvx/plotting/viser/scp.py +459 -0
- openscvx/plotting/viser/server.py +112 -0
- openscvx/problem.py +734 -0
- openscvx/propagation/__init__.py +60 -0
- openscvx/propagation/post_processing.py +104 -0
- openscvx/propagation/propagation.py +248 -0
- openscvx/solvers/__init__.py +51 -0
- openscvx/solvers/cvxpy.py +226 -0
- openscvx/symbolic/__init__.py +9 -0
- openscvx/symbolic/augmentation.py +630 -0
- openscvx/symbolic/builder.py +492 -0
- openscvx/symbolic/constraint_set.py +92 -0
- openscvx/symbolic/expr/__init__.py +222 -0
- openscvx/symbolic/expr/arithmetic.py +517 -0
- openscvx/symbolic/expr/array.py +632 -0
- openscvx/symbolic/expr/constraint.py +796 -0
- openscvx/symbolic/expr/control.py +135 -0
- openscvx/symbolic/expr/expr.py +720 -0
- openscvx/symbolic/expr/lie/__init__.py +87 -0
- openscvx/symbolic/expr/lie/adjoint.py +357 -0
- openscvx/symbolic/expr/lie/se3.py +172 -0
- openscvx/symbolic/expr/lie/so3.py +138 -0
- openscvx/symbolic/expr/linalg.py +279 -0
- openscvx/symbolic/expr/math.py +699 -0
- openscvx/symbolic/expr/spatial.py +209 -0
- openscvx/symbolic/expr/state.py +607 -0
- openscvx/symbolic/expr/stl.py +136 -0
- openscvx/symbolic/expr/variable.py +321 -0
- openscvx/symbolic/hashing.py +112 -0
- openscvx/symbolic/lower.py +760 -0
- openscvx/symbolic/lowerers/__init__.py +106 -0
- openscvx/symbolic/lowerers/cvxpy.py +1302 -0
- openscvx/symbolic/lowerers/jax.py +1382 -0
- openscvx/symbolic/preprocessing.py +757 -0
- openscvx/symbolic/problem.py +110 -0
- openscvx/symbolic/time.py +116 -0
- openscvx/symbolic/unified.py +420 -0
- openscvx/utils/__init__.py +20 -0
- openscvx/utils/cache.py +131 -0
- openscvx/utils/caching.py +210 -0
- openscvx/utils/printing.py +301 -0
- openscvx/utils/profiling.py +37 -0
- openscvx/utils/utils.py +100 -0
- openscvx-0.3.2.dev170.dist-info/METADATA +350 -0
- openscvx-0.3.2.dev170.dist-info/RECORD +79 -0
- openscvx-0.3.2.dev170.dist-info/WHEEL +5 -0
- openscvx-0.3.2.dev170.dist-info/licenses/LICENSE +201 -0
- openscvx-0.3.2.dev170.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,492 @@
|
|
|
1
|
+
"""Symbolic problem preprocessing and augmentation pipeline.
|
|
2
|
+
|
|
3
|
+
This module provides the main preprocessing pipeline for trajectory optimization problems,
|
|
4
|
+
transforming user-specified symbolic dynamics and constraints into an augmented form
|
|
5
|
+
ready for compilation to executable code.
|
|
6
|
+
|
|
7
|
+
The preprocessing pipeline is purely symbolic - no code generation occurs here. Instead,
|
|
8
|
+
it performs validation, canonicalization, and augmentation to prepare the problem for
|
|
9
|
+
efficient numerical solution.
|
|
10
|
+
|
|
11
|
+
Key functionality:
|
|
12
|
+
- Problem validation: Check shapes, variable names, constraint placement
|
|
13
|
+
- Time handling: Auto-create time state or validate user-provided time
|
|
14
|
+
- Canonicalization: Simplify expressions algebraically
|
|
15
|
+
- Parameter collection: Extract parameter values from expressions
|
|
16
|
+
- Constraint separation: Categorize constraints by type (CTCS, nodal, convex)
|
|
17
|
+
- CTCS augmentation: Add augmented states and time dilation for path constraints
|
|
18
|
+
- Propagation dynamics: Optionally extend dynamics for post-solution propagation
|
|
19
|
+
|
|
20
|
+
Lowering to JAX or CVXPy happens downstream, starting from the SymbolicProblem this module returns.
|
|
21
|
+
|
|
22
|
+
Pipeline stages:
|
|
23
|
+
1. Time handling & validation
|
|
24
|
+
2. Expression validation (shapes, names, constraint structure)
|
|
25
|
+
3. Canonicalization & parameter collection
|
|
26
|
+
4. Constraint separation & CTCS augmentation
|
|
27
|
+
5. Propagation dynamics creation
|
|
28
|
+
|
|
29
|
+
See `preprocess_symbolic_problem()` for the main entry point.
|
|
30
|
+
"""
|
|
31
|
+
|
|
32
|
+
from typing import Any, Dict, List, Optional, Tuple
|
|
33
|
+
|
|
34
|
+
import numpy as np
|
|
35
|
+
|
|
36
|
+
from openscvx.symbolic.augmentation import (
|
|
37
|
+
augment_dynamics_with_ctcs,
|
|
38
|
+
augment_with_time_state,
|
|
39
|
+
decompose_vector_nodal_constraints,
|
|
40
|
+
separate_constraints,
|
|
41
|
+
sort_ctcs_constraints,
|
|
42
|
+
)
|
|
43
|
+
from openscvx.symbolic.constraint_set import ConstraintSet
|
|
44
|
+
from openscvx.symbolic.expr import Constant, Parameter, traverse
|
|
45
|
+
from openscvx.symbolic.expr.control import Control
|
|
46
|
+
from openscvx.symbolic.expr.state import State
|
|
47
|
+
from openscvx.symbolic.preprocessing import (
|
|
48
|
+
collect_and_assign_slices,
|
|
49
|
+
convert_dynamics_dict_to_expr,
|
|
50
|
+
validate_and_normalize_constraint_nodes,
|
|
51
|
+
validate_constraints_at_root,
|
|
52
|
+
validate_dynamics_dict,
|
|
53
|
+
validate_dynamics_dict_dimensions,
|
|
54
|
+
validate_dynamics_dimension,
|
|
55
|
+
validate_shapes,
|
|
56
|
+
validate_time_parameters,
|
|
57
|
+
validate_variable_names,
|
|
58
|
+
)
|
|
59
|
+
from openscvx.symbolic.problem import SymbolicProblem
|
|
60
|
+
from openscvx.symbolic.time import Time
|
|
61
|
+
|
|
62
|
+
|
|
63
|
+
def preprocess_symbolic_problem(
    dynamics: dict,
    constraints: ConstraintSet,
    states: List[State],
    controls: List[Control],
    N: int,
    time: Time,
    licq_min: float = 0.0,
    licq_max: float = 1e-4,
    time_dilation_factor_min: float = 0.3,
    time_dilation_factor_max: float = 3.0,
    dynamics_prop_extra: Optional[dict] = None,
    states_prop_extra: Optional[List[State]] = None,
    byof: Optional[dict] = None,
) -> SymbolicProblem:
    """Preprocess and augment symbolic trajectory optimization problem.

    This is the main preprocessing pipeline that transforms a user-specified symbolic
    problem into an augmented form ready for compilation. It performs validation,
    canonicalization, constraint separation, and CTCS augmentation in a series of
    well-defined phases.

    The function is purely symbolic - no code generation or compilation occurs. The
    output is a SymbolicProblem dataclass that can be lowered to JAX or CVXPy by
    downstream compilation functions.

    Pipeline phases:
        1. Time handling & validation: Auto-create or validate time state
        2. Expression validation: Validate shapes, names, constraints
        3. Canonicalization & parameter collection: Simplify and extract parameters
        4. Constraint separation & augmentation: Sort constraints and add CTCS states
        5. Propagation dynamics creation: Optionally add extra states for simulation

    Args:
        dynamics: Dictionary mapping state names to dynamics expressions.
            Example: {"x": v, "v": u}. The dict is copied and not mutated.
        constraints: ConstraintSet with raw constraints in `unsorted` field.
            Create with: ConstraintSet(unsorted=[c1, c2, c3])
        states: List of user-defined State objects (should NOT include time or CTCS states)
        controls: List of user-defined Control objects (should NOT include time dilation)
        N: Number of discretization nodes in the trajectory
        time: Time configuration object specifying time bounds and constraints
        licq_min: Minimum bound for CTCS augmented states (default: 0.0)
        licq_max: Maximum bound for CTCS augmented states (default: 1e-4)
        time_dilation_factor_min: Minimum factor for time dilation control (default: 0.3)
        time_dilation_factor_max: Maximum factor for time dilation control (default: 3.0)
        dynamics_prop_extra: Optional dictionary of additional dynamics for propagation-only
            states (default: None). Only used together with `states_prop_extra`.
        states_prop_extra: Optional list of additional State objects for propagation only
            (default: None). Only used together with `dynamics_prop_extra`.
        byof: Optional dict of raw JAX functions for expert users. If byof contains
            a "dynamics" key, it should map state names to raw JAX functions with
            signature f(x, u, node, params) -> xdot_component. States in byof["dynamics"]
            should NOT appear in the symbolic dynamics dict.

    Returns:
        SymbolicProblem dataclass with:
            - dynamics: Augmented dynamics (user + time + CTCS penalties)
            - states: Augmented states (user + time + CTCS augmented)
            - controls: Augmented controls (user + time dilation)
            - constraints: ConstraintSet with is_categorized=True
            - parameters: Dict of extracted parameter values
            - node_intervals: List of (start, end) tuples for CTCS intervals
            - dynamics_prop: Propagation dynamics
            - states_prop: Propagation states
            - controls_prop: Propagation controls

    Raises:
        ValueError: If validation fails at any stage

    Warns:
        UserWarning: If exactly one of `dynamics_prop_extra` / `states_prop_extra`
            is provided; the half-specified extras are ignored.

    Note:
        The ConstraintSet the pipeline works on is modified in place: its `unsorted`
        list is canonicalized and then drained into the category lists. This may be
        the caller's object (it is whatever `augment_with_time_state` returns when a
        time state is auto-created, otherwise the argument itself).

    Example:
        Basic usage with CTCS constraint::

            import openscvx as ox
            from openscvx.symbolic.constraint_set import ConstraintSet

            x = ox.State("x", shape=(2,))
            v = ox.State("v", shape=(2,))
            u = ox.Control("u", shape=(2,))

            dynamics = {"x": v, "v": u}
            constraints = ConstraintSet(unsorted=[
                (ox.Norm(x) <= 5.0).over((0, 50))
            ])

            problem = preprocess_symbolic_problem(
                dynamics=dynamics,
                constraints=constraints,
                states=[x, v],
                controls=[u],
                N=50,
                time=ox.Time(initial=0.0, final=10.0)
            )

            assert problem.is_preprocessed
            # problem.dynamics: augmented dynamics expression
            # problem.states: [x, v, time, _ctcs_aug_0]
            # problem.controls: [u, _time_dilation]
            print([s.name for s in problem.states])
            # ['x', 'v', 'time', '_ctcs_aug_0']

        With propagation-only states::

            distance = ox.State("distance", shape=(1,))
            dynamics_extra = {"distance": ox.Norm(v)}

            problem = preprocess_symbolic_problem(
                dynamics=dynamics,
                constraints=constraints,
                states=[x, v],
                controls=[u],
                N=50,
                time=ox.Time(initial=0.0, final=10.0),
                dynamics_prop_extra=dynamics_extra,
                states_prop_extra=[distance]
            )

            # Propagation states include distance for post-solve simulation
            print([s.name for s in problem.states_prop])
    """

    # ==================== PHASE 1: Time Handling & Validation ====================

    # Validate time handling approach and get processed parameters.
    # The time derivative is not needed here: the "time" dynamics entry is fixed to 1.0 below.
    (
        has_time_state,
        time_initial,
        time_final,
        _time_derivative,
        time_min,
        time_max,
    ) = validate_time_parameters(states, time)

    # Augment states with time state if needed (auto-create approach)
    if not has_time_state:
        states, constraints = augment_with_time_state(
            states,
            constraints,
            time_initial,
            time_final,
            time_min,
            time_max,
            N,
            time_scaling_min=getattr(time, "scaling_min", None),
            time_scaling_max=getattr(time, "scaling_max", None),
        )

    # Add time derivative to dynamics dict (if not already present)
    # Time derivative is always 1.0 when using Time object
    dynamics = dict(dynamics)  # Make a copy to avoid mutating the input
    if "time" not in dynamics:
        dynamics["time"] = 1.0

    # Extract byof dynamics for validation
    byof_dynamics = byof.get("dynamics", {}) if byof else {}

    # Validate dynamics dict matches state names and dimensions
    # byof_dynamics states should not be in symbolic dynamics dict
    validate_dynamics_dict(dynamics, states, byof_dynamics=byof_dynamics)

    # Inject zero placeholders for byof dynamics states
    # These will be replaced with the actual byof functions at lowering time
    for state in states:
        if state.name in byof_dynamics:
            dynamics[state.name] = Constant(np.zeros(state.shape))

    # Validate dynamics dimensions AFTER injecting placeholders
    validate_dynamics_dict_dimensions(dynamics, states)

    # Convert dynamics dict to concatenated expression
    dynamics, dynamics_concat = convert_dynamics_dict_to_expr(dynamics, states)

    # ==================== PHASE 2: Expression Validation ====================

    # Validate all expressions (use unsorted constraints)
    all_exprs = [dynamics_concat] + constraints.unsorted
    validate_variable_names(all_exprs)
    # Slices must be assigned before shape validation so that variables know their extents
    collect_and_assign_slices(states, controls)
    validate_shapes(all_exprs)
    validate_constraints_at_root(constraints.unsorted)
    validate_and_normalize_constraint_nodes(constraints.unsorted, N)
    validate_dynamics_dimension(dynamics_concat, states)

    # ==================== PHASE 3: Canonicalization & Parameter Collection ====================

    # Canonicalize all expressions after validation
    dynamics_concat = dynamics_concat.canonicalize()
    constraints.unsorted = [expr.canonicalize() for expr in constraints.unsorted]

    # Collect parameter values from all constraints and dynamics.
    # The first value seen for a given parameter name wins.
    parameters = {}

    def collect_param_values(expr):
        if isinstance(expr, Parameter):
            if expr.name not in parameters:
                parameters[expr.name] = expr.value

    # Collect from dynamics
    traverse(dynamics_concat, collect_param_values)

    # Collect from constraints
    for constraint in constraints.unsorted:
        traverse(constraint, collect_param_values)

    # ==================== PHASE 4: Constraint Separation & Augmentation ====================

    # Sort and separate constraints by type (drains unsorted -> fills categories)
    separate_constraints(constraints, N)

    # Decompose vector-valued nodal constraints into scalar constraints
    # This is necessary for non-convex nodal constraints that get lowered to JAX
    constraints.nodal = decompose_vector_nodal_constraints(constraints.nodal)

    # Sort CTCS constraints by their idx to get node_intervals
    constraints.ctcs, node_intervals, _ = sort_ctcs_constraints(constraints.ctcs)

    # Augment dynamics, states, and controls with CTCS constraints, time dilation
    dynamics_aug, states_aug, controls_aug = augment_dynamics_with_ctcs(
        dynamics_concat,
        states,
        controls,
        constraints.ctcs,
        N,
        licq_min=licq_min,
        licq_max=licq_max,
        time_dilation_factor_min=time_dilation_factor_min,
        time_dilation_factor_max=time_dilation_factor_max,
    )

    # Assign slices to augmented states and controls in canonical order
    collect_and_assign_slices(states_aug, controls_aug)

    # ==================== PHASE 5: Create Propagation Dynamics ====================

    # By default, propagation dynamics are the same as optimization dynamics
    # Use deepcopy to avoid reference issues when lowering
    from copy import deepcopy

    dynamics_prop = deepcopy(dynamics_aug)
    states_prop = list(states_aug)  # Shallow copy of list is fine for states
    controls_prop = list(controls_aug)

    # Propagation-only states need BOTH their State objects and their dynamics.
    # A half-specified pair cannot be used; warn rather than dropping it silently.
    if (dynamics_prop_extra is None) != (states_prop_extra is None):
        import warnings

        warnings.warn(
            "dynamics_prop_extra and states_prop_extra must be provided together; "
            "ignoring the propagation-only extras because only one of them was given.",
            UserWarning,
            stacklevel=2,
        )

    # If user provided extra propagation states, extend propagation dynamics
    if dynamics_prop_extra is not None and states_prop_extra is not None:
        (
            dynamics_prop,
            states_prop,
            controls_prop,
            parameters,
        ) = add_propagation_states(
            dynamics_extra=dynamics_prop_extra,
            states_extra=states_prop_extra,
            dynamics_opt=dynamics_prop,
            states_opt=states_prop,
            controls_opt=controls_prop,
            parameters=parameters,
        )

    # ==================== Return SymbolicProblem ====================

    return SymbolicProblem(
        dynamics=dynamics_aug,
        states=states_aug,
        controls=controls_aug,
        constraints=constraints,
        parameters=parameters,
        N=N,
        node_intervals=node_intervals,
        dynamics_prop=dynamics_prop,
        states_prop=states_prop,
        controls_prop=controls_prop,
    )
|
|
335
|
+
|
|
336
|
+
|
|
337
|
+
def add_propagation_states(
    dynamics_extra: dict,
    states_extra: List[State],
    dynamics_opt: Any,
    states_opt: List[State],
    controls_opt: List[Control],
    parameters: Dict[str, Any],
) -> Tuple:
    """Extend optimization dynamics with additional propagation-only states.

    This function augments the optimization dynamics with extra states that are only
    needed for post-solution trajectory propagation and simulation. These states
    don't affect the optimization but are useful for computing derived quantities
    like distance traveled, energy consumed, or accumulated cost.

    Propagation-only states are NOT part of the optimization problem - they are
    integrated forward after solving using the optimized state and control trajectories.
    This is more efficient than including them as optimization variables.

    The user specifies only the ADDITIONAL states and their dynamics. These are
    appended after all optimization states (user states + time + CTCS augmented states).

    State ordering in propagation dynamics:
        [user_states, time, ctcs_aug_states, extra_prop_states]

    Args:
        dynamics_extra: Dictionary mapping extra state names to dynamics expressions.
            Only specify NEW states, not optimization states. Example: {"distance": speed}
        states_extra: List of extra State objects for propagation only. The list itself
            is copied, but each State object's `_slice` is (re)assigned in place.
        dynamics_opt: Augmented optimization dynamics expression (from preprocessing)
        states_opt: Augmented optimization states (user + time + CTCS augmented).
            Their slices must already be assigned; they are not modified.
        controls_opt: Augmented optimization controls (user + time dilation)
        parameters: Dictionary of parameter values from optimization preprocessing.
            It is copied, not mutated.

    Returns:
        Tuple containing:
            - dynamics_prop (Expr): Extended dynamics (optimization + extra)
            - states_prop (List[State]): Extended states (optimization + extra)
            - controls_prop (List[Control]): Same as controls_opt
            - parameters_updated (Dict): Updated parameters including any from extra dynamics

    Raises:
        ValueError: If extra states conflict with optimization state names or if
            validation fails

    Example:
        Adding distance and energy tracking for propagation::

            # After preprocessing, add propagation states
            import openscvx as ox
            import numpy as np

            # Define extra states for tracking
            distance = ox.State("distance", shape=(1,))
            distance.initial = np.array([0.0])

            energy = ox.State("energy", shape=(1,))
            energy.initial = np.array([0.0])

            # Define their dynamics (using optimization states/controls)
            # Assume v and u are optimization states/controls
            dynamics_extra = {
                "distance": ox.Norm(v),  # Integrate velocity magnitude
                "energy": ox.Norm(u)**2  # Integrate squared control
            }

            dyn_prop, states_prop, controls_prop, params = add_propagation_states(
                dynamics_extra=dynamics_extra,
                states_extra=[distance, energy],
                dynamics_opt=dynamics_aug,
                states_opt=states_aug,
                controls_opt=controls_aug,
                parameters=parameters
            )

            # Now states_prop includes all states for forward simulation
            # distance and energy will be integrated during propagation

    Note:
        The extra states should have initial conditions set, as they will be
        integrated from these initial values during propagation.
    """

    # Copy the containers so the caller's list/dicts are not mutated
    # (the State objects inside states_extra still get their `_slice` reassigned below)
    states_extra = list(states_extra)
    dynamics_extra = dict(dynamics_extra)
    parameters = dict(parameters)

    # ==================== PHASE 1: Validate Extra States ====================

    # Validate that extra states don't conflict with optimization state names
    opt_state_names = {s.name for s in states_opt}
    extra_state_names = {s.name for s in states_extra}
    conflicts = opt_state_names & extra_state_names
    if conflicts:
        raise ValueError(
            f"Extra propagation states conflict with optimization states: {conflicts}. "
            "Only specify additional states, not optimization states."
        )

    # Validate dynamics dict for extra states
    validate_dynamics_dict(dynamics_extra, states_extra)
    validate_dynamics_dict_dimensions(dynamics_extra, states_extra)

    # ==================== PHASE 2: Process Extra Dynamics ====================

    # Convert extra dynamics to expression
    _, dynamics_extra_concat = convert_dynamics_dict_to_expr(dynamics_extra, states_extra)

    # Validate and canonicalize
    validate_variable_names([dynamics_extra_concat])

    # Temporarily assign slices for validation (extra-state slices are recalculated below).
    # NOTE(review): this call also receives controls_opt; presumably it re-derives the same
    # control slices as in preprocessing — confirm collect_and_assign_slices is idempotent.
    collect_and_assign_slices(states_extra, controls_opt)
    validate_shapes([dynamics_extra_concat])
    validate_dynamics_dimension(dynamics_extra_concat, states_extra)
    dynamics_extra_concat = dynamics_extra_concat.canonicalize()

    # Collect any new parameter values from extra dynamics (existing values win)
    def collect_param_values(expr):
        if isinstance(expr, Parameter):
            if expr.name not in parameters:
                parameters[expr.name] = expr.value

    traverse(dynamics_extra_concat, collect_param_values)

    # ==================== PHASE 3: Concatenate with Optimization Dynamics ====================

    # Concatenate: {opt dynamics, extra dynamics}
    from openscvx.symbolic.expr import Concat

    dynamics_prop = Concat(dynamics_opt, dynamics_extra_concat)

    # Manually assign slices to extra states ONLY (don't modify optimization state slices).
    # Extra states are appended after all optimization states. Use the largest slice
    # stop rather than the last list entry, so the offset stays correct even when the
    # optimization slices were assigned in an order different from list order.
    start_idx = max((s._slice.stop for s in states_opt), default=0)
    for state in states_extra:
        end_idx = start_idx + state.shape[0]
        state._slice = slice(start_idx, end_idx)
        start_idx = end_idx

    # Append extra states to optimization states
    states_prop = states_opt + states_extra

    # Propagation uses same controls as optimization
    controls_prop = controls_opt

    # ==================== Return Symbolic Outputs ====================

    return (
        dynamics_prop,
        states_prop,
        controls_prop,
        parameters,
    )
|
|
@@ -0,0 +1,92 @@
|
|
|
1
|
+
"""Container for categorized symbolic constraints.
|
|
2
|
+
|
|
3
|
+
This module provides a dataclass to hold all symbolic constraint types in a
|
|
4
|
+
structured way before they are lowered to JAX/CVXPy.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
from dataclasses import dataclass, field
|
|
8
|
+
from typing import TYPE_CHECKING, List, Union
|
|
9
|
+
|
|
10
|
+
if TYPE_CHECKING:
|
|
11
|
+
from openscvx.symbolic.expr import CTCS, Constraint, CrossNodeConstraint, NodalConstraint
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
@dataclass
class ConstraintSet:
    """Structured holder for symbolic constraints, grouped by category.

    Every symbolic constraint kind gets its own typed list, giving downstream
    code a clear, type-safe way to reach each category. This container is used
    only before lowering; once lowered, constraints are held by
    LoweredJaxConstraints and LoweredCvxpyConstraints instead.

    A constraint set moves through two lifecycle stages:

    1. **Before preprocessing**: raw constraints sit in `unsorted`
    2. **After preprocessing**: `unsorted` is empty and each constraint lives
       in its category list

    `is_categorized` reports which of the two stages the set is in.

    Attributes:
        unsorted: Raw constraints awaiting categorization. Empty after preprocessing.
        ctcs: CTCS (continuous-time) constraints.
        nodal: Non-convex nodal constraints (will be lowered to JAX).
        nodal_convex: Convex nodal constraints (will be lowered to CVXPy).
        cross_node: Non-convex cross-node constraints (will be lowered to JAX).
        cross_node_convex: Convex cross-node constraints (will be lowered to CVXPy).

    Example:
        Before preprocessing (raw constraints)::

            constraints = ConstraintSet(unsorted=[c1, c2, c3])
            assert not constraints.is_categorized

        After preprocessing (categorized)::

            # preprocess_symbolic_problem drains unsorted -> fills categories
            assert constraints.is_categorized
            for c in constraints.nodal:
                # Process non-convex nodal constraints
                pass
    """

    # Raw constraints awaiting categorization (emptied by preprocessing)
    unsorted: List[Union["Constraint", "CTCS"]] = field(default_factory=list)

    # Category lists, filled in by preprocessing
    ctcs: List["CTCS"] = field(default_factory=list)
    nodal: List["NodalConstraint"] = field(default_factory=list)
    nodal_convex: List["NodalConstraint"] = field(default_factory=list)
    cross_node: List["CrossNodeConstraint"] = field(default_factory=list)
    cross_node_convex: List["CrossNodeConstraint"] = field(default_factory=list)

    @property
    def is_categorized(self) -> bool:
        """Report whether the `unsorted` list has been fully drained.

        Once preprocessing has run, `unsorted` is expected to be empty, with
        every constraint moved into the matching category list.
        """
        return not self.unsorted

    def __bool__(self) -> bool:
        """Return True when at least one of the constraint lists has an entry."""
        return any(
            (
                self.unsorted,
                self.ctcs,
                self.nodal,
                self.nodal_convex,
                self.cross_node,
                self.cross_node_convex,
            )
        )

    def __len__(self) -> int:
        """Return the combined number of constraints held across every list."""
        return sum(
            map(
                len,
                (
                    self.unsorted,
                    self.ctcs,
                    self.nodal,
                    self.nodal_convex,
                    self.cross_node,
                    self.cross_node_convex,
                ),
            )
        )
|