openscvx 0.3.2.dev170__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of openscvx might be problematic. Click here for more details.
- openscvx/__init__.py +123 -0
- openscvx/_version.py +34 -0
- openscvx/algorithms/__init__.py +92 -0
- openscvx/algorithms/autotuning.py +24 -0
- openscvx/algorithms/base.py +351 -0
- openscvx/algorithms/optimization_results.py +215 -0
- openscvx/algorithms/penalized_trust_region.py +384 -0
- openscvx/config.py +437 -0
- openscvx/discretization/__init__.py +47 -0
- openscvx/discretization/discretization.py +236 -0
- openscvx/expert/__init__.py +23 -0
- openscvx/expert/byof.py +326 -0
- openscvx/expert/lowering.py +419 -0
- openscvx/expert/validation.py +357 -0
- openscvx/integrators/__init__.py +48 -0
- openscvx/integrators/runge_kutta.py +281 -0
- openscvx/lowered/__init__.py +30 -0
- openscvx/lowered/cvxpy_constraints.py +23 -0
- openscvx/lowered/cvxpy_variables.py +124 -0
- openscvx/lowered/dynamics.py +34 -0
- openscvx/lowered/jax_constraints.py +133 -0
- openscvx/lowered/parameters.py +54 -0
- openscvx/lowered/problem.py +70 -0
- openscvx/lowered/unified.py +718 -0
- openscvx/plotting/__init__.py +63 -0
- openscvx/plotting/plotting.py +756 -0
- openscvx/plotting/scp_iteration.py +299 -0
- openscvx/plotting/viser/__init__.py +126 -0
- openscvx/plotting/viser/animated.py +605 -0
- openscvx/plotting/viser/plotly_integration.py +333 -0
- openscvx/plotting/viser/primitives.py +355 -0
- openscvx/plotting/viser/scp.py +459 -0
- openscvx/plotting/viser/server.py +112 -0
- openscvx/problem.py +734 -0
- openscvx/propagation/__init__.py +60 -0
- openscvx/propagation/post_processing.py +104 -0
- openscvx/propagation/propagation.py +248 -0
- openscvx/solvers/__init__.py +51 -0
- openscvx/solvers/cvxpy.py +226 -0
- openscvx/symbolic/__init__.py +9 -0
- openscvx/symbolic/augmentation.py +630 -0
- openscvx/symbolic/builder.py +492 -0
- openscvx/symbolic/constraint_set.py +92 -0
- openscvx/symbolic/expr/__init__.py +222 -0
- openscvx/symbolic/expr/arithmetic.py +517 -0
- openscvx/symbolic/expr/array.py +632 -0
- openscvx/symbolic/expr/constraint.py +796 -0
- openscvx/symbolic/expr/control.py +135 -0
- openscvx/symbolic/expr/expr.py +720 -0
- openscvx/symbolic/expr/lie/__init__.py +87 -0
- openscvx/symbolic/expr/lie/adjoint.py +357 -0
- openscvx/symbolic/expr/lie/se3.py +172 -0
- openscvx/symbolic/expr/lie/so3.py +138 -0
- openscvx/symbolic/expr/linalg.py +279 -0
- openscvx/symbolic/expr/math.py +699 -0
- openscvx/symbolic/expr/spatial.py +209 -0
- openscvx/symbolic/expr/state.py +607 -0
- openscvx/symbolic/expr/stl.py +136 -0
- openscvx/symbolic/expr/variable.py +321 -0
- openscvx/symbolic/hashing.py +112 -0
- openscvx/symbolic/lower.py +760 -0
- openscvx/symbolic/lowerers/__init__.py +106 -0
- openscvx/symbolic/lowerers/cvxpy.py +1302 -0
- openscvx/symbolic/lowerers/jax.py +1382 -0
- openscvx/symbolic/preprocessing.py +757 -0
- openscvx/symbolic/problem.py +110 -0
- openscvx/symbolic/time.py +116 -0
- openscvx/symbolic/unified.py +420 -0
- openscvx/utils/__init__.py +20 -0
- openscvx/utils/cache.py +131 -0
- openscvx/utils/caching.py +210 -0
- openscvx/utils/printing.py +301 -0
- openscvx/utils/profiling.py +37 -0
- openscvx/utils/utils.py +100 -0
- openscvx-0.3.2.dev170.dist-info/METADATA +350 -0
- openscvx-0.3.2.dev170.dist-info/RECORD +79 -0
- openscvx-0.3.2.dev170.dist-info/WHEEL +5 -0
- openscvx-0.3.2.dev170.dist-info/licenses/LICENSE +201 -0
- openscvx-0.3.2.dev170.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,757 @@
|
|
|
1
|
+
"""Validation and preprocessing utilities for symbolic expressions.
|
|
2
|
+
|
|
3
|
+
This module provides preprocessing and validation functions for symbolic expressions
|
|
4
|
+
in trajectory optimization problems. These utilities ensure that expressions are
|
|
5
|
+
well-formed and constraints are properly specified before compilation to solvers.
|
|
6
|
+
|
|
7
|
+
The preprocessing pipeline includes:
|
|
8
|
+
- Shape validation: Ensure all expressions have compatible shapes
|
|
9
|
+
- Variable name validation: Check for unique, non-reserved variable names
|
|
10
|
+
- Constraint validation: Verify constraints appear only at root level
|
|
11
|
+
- Dynamics validation: Check that dynamics match state dimensions
|
|
12
|
+
- Time parameter validation: Validate time configuration
|
|
13
|
+
- Slice assignment: Assign contiguous memory slices to variables
|
|
14
|
+
|
|
15
|
+
These functions are typically called automatically during problem construction,
|
|
16
|
+
but can also be used manually for debugging or custom problem setups.
|
|
17
|
+
|
|
18
|
+
Example:
|
|
19
|
+
Validating expressions before problem construction::
|
|
20
|
+
|
|
21
|
+
import openscvx as ox
|
|
22
|
+
|
|
23
|
+
x = ox.State("x", shape=(3,))
|
|
24
|
+
u = ox.Control("u", shape=(2,))
|
|
25
|
+
|
|
26
|
+
# Build dynamics and constraints
|
|
27
|
+
dynamics = {
|
|
28
|
+
"x": u # Will fail validation - dimension mismatch!
|
|
29
|
+
}
|
|
30
|
+
|
|
31
|
+
# Validate dimensions before creating problem
|
|
32
|
+
from openscvx.symbolic.preprocessing import validate_dynamics_dict_dimensions
|
|
33
|
+
|
|
34
|
+
try:
|
|
35
|
+
validate_dynamics_dict_dimensions(dynamics, [x])
|
|
36
|
+
except ValueError as e:
|
|
37
|
+
print(f"Validation error: {e}")
|
|
38
|
+
"""
|
|
39
|
+
|
|
40
|
+
from typing import TYPE_CHECKING, Callable, Dict, Iterable, List, Optional, Set, Tuple, Union
|
|
41
|
+
|
|
42
|
+
if TYPE_CHECKING:
|
|
43
|
+
from openscvx.symbolic.time import Time
|
|
44
|
+
|
|
45
|
+
import numpy as np
|
|
46
|
+
|
|
47
|
+
from openscvx.symbolic.expr import (
|
|
48
|
+
CTCS,
|
|
49
|
+
Concat,
|
|
50
|
+
Constant,
|
|
51
|
+
Constraint,
|
|
52
|
+
Control,
|
|
53
|
+
CrossNodeConstraint,
|
|
54
|
+
Expr,
|
|
55
|
+
NodalConstraint,
|
|
56
|
+
State,
|
|
57
|
+
traverse,
|
|
58
|
+
)
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
def validate_shapes(exprs: Union[Expr, list[Expr]]) -> None:
    """Run shape checking on one expression or on a list/tuple of them.

    Args:
        exprs: A single expression, or a list/tuple of expressions.

    Raises:
        ValueError: Propagated from ``check_shape()`` when an expression's shapes
            are inconsistent.
    """
    if not isinstance(exprs, (list, tuple)):
        exprs = [exprs]
    for expr in exprs:
        # check_shape() raises ValueError on the first inconsistency it finds.
        expr.check_shape()
|
|
73
|
+
|
|
74
|
+
|
|
75
|
+
# TODO: (norrisg) allow `traverse` to take a list of visitors, that way we can combine steps
|
|
76
|
+
def validate_variable_names(
    exprs: Iterable[Expr],
    *,
    reserved_prefix: str = "_",
    reserved_names: Set[str] = None,
) -> None:
    """Check State/Control names for duplicates and reserved-name violations.

    Every State and Control reachable from ``exprs`` is visited once per distinct
    object (identity, not equality). For each new object, its name must:

    1. not already belong to a *different* State/Control object,
    2. not start with ``reserved_prefix``, and
    3. not be a member of ``reserved_names``.

    Args:
        exprs: Expression trees to scan; each is walked with ``traverse``.
        reserved_prefix: Prefix that user variable names may not start with.
        reserved_names: Optional collection of names that may not be used.

    Raises:
        ValueError: On the first name that breaks one of the rules above.

    Example:
        x1 = ox.State("x", shape=(3,))
        x2 = ox.State("x", shape=(2,))  # distinct object, same name
        validate_variable_names([x1 + x2])  # ValueError: duplicate 'x'

        bad = ox.State("_internal", shape=(2,))
        validate_variable_names([bad])  # ValueError: reserved prefix '_'
    """
    names_in_use = set()
    visited_ids = set()
    forbidden = set(reserved_names) if reserved_names else set()

    def _check(node):
        # Only variables carry user-chosen names; an object already seen is fine.
        if not isinstance(node, (State, Control)) or id(node) in visited_ids:
            return

        name = node.name
        if name in names_in_use:
            raise ValueError(f"Duplicate variable name: {name!r}")
        if name.startswith(reserved_prefix):
            raise ValueError(
                f"Variable name {name!r} is reserved (cannot start with {reserved_prefix!r})"
            )
        if name in forbidden:
            raise ValueError(f"Variable name {name!r} collides with reserved name")

        names_in_use.add(name)
        visited_ids.add(id(node))

    for expr in exprs:
        traverse(expr, _check)
|
|
139
|
+
|
|
140
|
+
|
|
141
|
+
def collect_and_assign_slices(
    states: List[State], controls: List[Control], *, start_index: int = 0
) -> Tuple[list[State], list[Control]]:
    """Assign contiguous memory slices to states and controls.

    This function sets the ``_slice`` attribute of each state and control, which
    determines its position in the flat decision variable vector. States and
    controls are laid out independently, each starting at ``start_index``.
    Variables can have either:

    - Manual slices: ``_slice`` already set by the user.
    - Auto-assigned slices: ``_slice`` is ``None``. These are placed contiguously,
      in list order, after the manual ones.

    If any variables have manual slices, those slices must:

    - Start at ``start_index`` (0 by default)
    - Be contiguous and non-overlapping
    - Have length equal to the variable's flattened dimension (``prod(shape)``)

    Args:
        states: List of State objects in canonical order
        controls: List of Control objects in canonical order
        start_index: Starting index for slice assignment (default: 0)

    Returns:
        Tuple of (states, controls): the same list objects that were passed in,
        with every element's ``_slice`` attribute set.

    Raises:
        ValueError: If manual slices have the wrong length, overlap or leave gaps,
            or do not start at ``start_index``.

    Example:
        x = ox.State("x", shape=(3,))
        u = ox.Control("u", shape=(2,))
        states, controls = collect_and_assign_slices([x], [u])
        print(x._slice)  # slice(0, 3)
        print(u._slice)  # slice(0, 2)
    """

    def _flat_dim(v) -> int:
        # Number of scalar entries the variable occupies in the flat vector.
        return int(np.prod(v.shape))

    def assign(vars_list, start_index):
        # Split into user-specified (manual) and to-be-assigned (auto) variables.
        manual = [v for v in vars_list if v._slice is not None]
        auto = [v for v in vars_list if v._slice is None]

        offset = start_index
        if manual:
            # 1) Each manual slice must be exactly as long as the variable.
            for v in manual:
                dim = _flat_dim(v)
                sl = v._slice
                if (sl.stop - sl.start) != dim:
                    raise ValueError(
                        f"Manual slice for {v.name!r} is length {sl.stop - sl.start}, "
                        f"but variable has shape {v.shape} (dim {dim})"
                    )

            # Sorting by start index lets contiguity be checked in one pass.
            manual.sort(key=lambda v: v._slice.start)

            # 2a) The manual block must begin at start_index.
            if manual[0]._slice.start != start_index:
                raise ValueError(f"User-defined slices must start at index {start_index}")

            # 2b) Manual slices must tile [start_index, offset) with no gaps or overlaps.
            for v in manual:
                sl = v._slice
                dim = _flat_dim(v)
                if sl.start != offset or sl.stop != offset + dim:
                    raise ValueError(
                        f"Manual slice for {v.name!r} must be contiguous and non-overlapping"
                    )
                offset += dim

        # 3) Auto-assigned variables are packed right after the manual block.
        for v in auto:
            dim = _flat_dim(v)
            v._slice = slice(offset, offset + dim)
            offset += dim

    # States (x) and controls (u) live in separate vectors, so lay out each separately.
    assign(states, start_index)
    assign(controls, start_index)

    return states, controls
|
|
221
|
+
|
|
222
|
+
|
|
223
|
+
def _traverse_with_depth(expr: Expr, visit: Callable[[Expr, int], None], depth: int = 0):
    """Walk an expression tree in pre-order, passing each node's depth to ``visit``.

    A node is visited before any of its children (taken from ``children()``), and
    each child is reported at its parent's depth plus one.

    Args:
        expr: Node to start from.
        visit: Called as ``visit(node, depth)`` for every node reached.
        depth: Depth reported for ``expr`` itself (default: 0).
    """
    visit(expr, depth)
    child_depth = depth + 1
    for sub_expr in expr.children():
        _traverse_with_depth(sub_expr, visit, child_depth)
|
|
237
|
+
|
|
238
|
+
|
|
239
|
+
def validate_constraints_at_root(exprs: Union[Expr, list[Expr]]):
    """Reject constraints or constraint wrappers that appear below the root.

    Bare constraints and the wrappers CTCS, NodalConstraint and CrossNodeConstraint
    may only be top-level expressions. When a root is itself a wrapper, its
    subtree is not inspected, so the constraint it wraps (e.g. the ``x <= 5`` in
    ``CTCS(x <= 5)``) is accepted.

    Args:
        exprs: A single expression, or a list/tuple of expressions.

    Raises:
        ValueError: If a wrapper or a Constraint is found at depth > 0 (wrappers
            are tested first).

    Example:
        x = ox.State("x", shape=(3,))
        validate_constraints_at_root([x <= 5])  # OK: constraint is the root

        bad_expr = ox.Sum(x <= 5)  # constraint nested inside Sum
        validate_constraints_at_root([bad_expr])  # raises ValueError
    """
    # Wrapper types that, like bare constraints, are only legal as roots.
    CONSTRAINT_WRAPPERS = (CTCS, NodalConstraint, CrossNodeConstraint)

    roots = exprs if isinstance(exprs, (list, tuple)) else [exprs]

    def _check(node: Expr, depth: int) -> None:
        nested = depth > 0
        is_wrapper = isinstance(node, CONSTRAINT_WRAPPERS)

        if nested and is_wrapper:
            raise ValueError(
                f"Nested constraint wrapper found at depth {depth!r}: {node!r}; "
                "constraint wrappers must only appear as top-level roots"
            )
        if nested and isinstance(node, Constraint):
            raise ValueError(
                f"Nested Constraint found at depth {depth!r}: {node!r}; "
                "constraints must only appear as top-level roots"
            )

        # A root wrapper's payload is allowed to contain a constraint; stop here.
        if is_wrapper:
            return

        for child in node.children():
            _check(child, depth + 1)

    for root in roots:
        _check(root, 0)
|
|
297
|
+
|
|
298
|
+
|
|
299
|
+
def validate_and_normalize_constraint_nodes(exprs: Union[Expr, list[Expr]], n_nodes: int):
    """Range-check node specifications of constraint wrappers, filling in defaults.

    CTCS constraints (``nodes`` is a ``(start, end)`` interval):
      - ``nodes is None`` is replaced, in place, by ``(0, n_nodes)`` so that the
        constraint spans the whole trajectory.
      - Otherwise an error is raised when ``start >= n_nodes`` or
        ``end > n_nodes``. Other structural checks (tuple length, ordering) are
        expected to have been done by the CTCS constructor.

    NodalConstraint (``nodes`` is a collection of node indices, e.g. [2, 4, 6]):
      - Every index must lie in ``[0, n_nodes)``.

    Expressions of any other type are left untouched.

    Args:
        exprs: A single expression, or a list/tuple of expressions.
        n_nodes: Number of nodes in the discretized trajectory.

    Raises:
        ValueError: If a CTCS interval or a nodal index falls outside the trajectory.

    Example:
        x = ox.State("x", shape=(3,))
        nodal = (x <= 5).at([0, 10, 20])
        validate_and_normalize_constraint_nodes([nodal], n_nodes=50)  # OK

        ctcs = (x <= 5).over((0, 100))
        validate_and_normalize_constraint_nodes([ctcs], n_nodes=50)
        # raises ValueError: range exceeds trajectory length
    """
    roots = exprs if isinstance(exprs, (list, tuple)) else [exprs]

    for expr in roots:
        if isinstance(expr, CTCS):
            if expr.nodes is None:
                # No explicit interval: apply over the entire trajectory.
                expr.nodes = (0, n_nodes)
                continue
            if expr.nodes[0] >= n_nodes or expr.nodes[1] > n_nodes:
                raise ValueError(
                    f"CTCS node range {expr.nodes} exceeds trajectory length {n_nodes}"
                )
            continue

        if not isinstance(expr, NodalConstraint):
            continue

        for node in expr.nodes:
            if not (0 <= node < n_nodes):
                raise ValueError(f"NodalConstraint node {node} is out of range [0, {n_nodes})")
|
|
350
|
+
|
|
351
|
+
|
|
352
|
+
def validate_cross_node_constraint(cross_node_constraint, n_nodes: int) -> None:
    """Check a CrossNodeConstraint's node indices and its use of ``.at()``.

    One pre-order walk over the constraint's ``lhs`` and then its ``rhs`` gathers:

    * every ``NodeReference`` (a variable pinned to a node via ``.at(k)``). A
      negative ``k`` is mapped to ``n_nodes + k``; the wrapped variable is not
      descended into.
    * the name of every ``State``/``Control`` that is *not* inside a
      ``NodeReference``.

    Two rules are then enforced, in this order:

    1. **Bounds**: every normalized node index lies in ``[0, n_nodes)``.
    2. **Consistency**: if any ``NodeReference`` is present, no bare
       ``State``/``Control`` may appear. During lowering, ``.at(k)`` selects one
       node (``X[k, :]``, shape ``(n_x,)``) while a bare variable stands for the
       whole trajectory (``X[:, :]``, shape ``(N, n_x)``), so mixing them
       produces shape mismatches.

    Args:
        cross_node_constraint: The CrossNodeConstraint to validate.
        n_nodes: Total number of trajectory nodes.

    Raises:
        TypeError: If ``cross_node_constraint`` is not a CrossNodeConstraint.
        ValueError: If a NodeReference index is outside ``[0, n_nodes)`` after
            normalization.
        ValueError: If ``.at()`` variables are mixed with bare variables.

    Example:
        Valid cross-node constraint:

            from openscvx.symbolic.expr import CrossNodeConstraint

            position = State("pos", shape=(3,))

            # Valid: all variables use .at(), indices in bounds
            constraint = CrossNodeConstraint(position.at(5) - position.at(4) <= 0.1)
            validate_cross_node_constraint(constraint, n_nodes=10)  # OK

        Invalid - out of bounds:

            # Invalid: node 10 is out of bounds for n_nodes=10
            bad_bounds = CrossNodeConstraint(position.at(0) == position.at(10))
            validate_cross_node_constraint(bad_bounds, n_nodes=10)  # Raises ValueError

        Invalid - mixed .at() usage:

            velocity = State("vel", shape=(3,))
            # Invalid: position uses .at(), velocity doesn't
            bad_mixed = CrossNodeConstraint(position.at(5) - velocity <= 0.1)
            validate_cross_node_constraint(bad_mixed, n_nodes=10)  # Raises ValueError
    """
    from openscvx.symbolic.expr import Control, CrossNodeConstraint, NodeReference, State

    if not isinstance(cross_node_constraint, CrossNodeConstraint):
        raise TypeError(
            f"Expected CrossNodeConstraint, got {type(cross_node_constraint).__name__}. "
            "Bare constraints with NodeReferences should be wrapped in CrossNodeConstraint "
            "by separate_constraints() before validation."
        )

    constraint = cross_node_constraint.constraint

    pinned = []  # (original index, normalized index) for each NodeReference
    bare_names = []  # names of State/Control leaves not wrapped by .at()

    def _walk(node):
        if isinstance(node, NodeReference):
            raw = node.node_idx
            pinned.append((raw, raw if raw >= 0 else n_nodes + raw))
            # The variable inside the NodeReference is intentionally not visited.
            return
        if isinstance(node, (State, Control)):
            bare_names.append(node.name)
            return
        for child in node.children():
            _walk(child)

    for side in (constraint.lhs, constraint.rhs):
        _walk(side)

    # Rule 1: every pinned node must exist in the trajectory.
    for raw, idx in pinned:
        if not (0 <= idx < n_nodes):
            raise ValueError(
                f"Cross-node constraint references invalid node index {raw}. "
                f"Node indices must be in range [0, {n_nodes}) "
                f"(or negative indices in range [-{n_nodes}, -1]). "
                f"Constraint: {constraint}"
            )

    # Rule 2: once anything is pinned with .at(), everything must be.
    if pinned and bare_names:
        raise ValueError(
            "Cross-node constraint contains NodeReferences (variables with .at(k)) "
            f"but also has variables without .at(): {bare_names}. "
            "All state/control variables in cross-node constraints must use .at(k). "
            "For example, if you use 'position.at(5)', you must also use 'velocity.at(4)' "
            "instead of just 'velocity'. "
            f"Constraint: {constraint}"
        )
|
|
457
|
+
|
|
458
|
+
|
|
459
|
+
def validate_dynamics_dimension(
    dynamics_expr: Union[Expr, list[Expr]], states: Union[State, list[State]]
) -> None:
    """Check that the dynamics right-hand side has as many entries as the state.

    Every dynamics expression must have a 1-D shape from ``check_shape()``. The
    sum of their lengths must equal the sum of ``prod(state.shape)`` over all
    states, so that the ODE system x_dot = f(x, u, t) is well-formed.

    Args:
        dynamics_expr: One dynamics expression, or a list/tuple of them that
            together make up x_dot = f(x, u, t).
        states: One State, or a list/tuple of the States being described.

    Raises:
        ValueError: If some dynamics expression is not 1-D, or if the combined
            dynamics length differs from the combined state dimension.

    Example:
        x = ox.State("x", shape=(3,))
        y = ox.State("y", shape=(2,))
        dynamics = ox.Concat(x * 2, y + 1)  # shape (5,): matches 3 + 2
        validate_dynamics_dimension(dynamics, [x, y])  # OK

        validate_dynamics_dimension(x, [x, y])  # shape (3,) != 5: ValueError
    """
    exprs_list = dynamics_expr if isinstance(dynamics_expr, (list, tuple)) else [dynamics_expr]
    state_list = states if isinstance(states, (list, tuple)) else [states]
    several = len(exprs_list) > 1

    expected_dim = sum(int(np.prod(s.shape)) for s in state_list)

    # Length of each (validated) 1-D dynamics expression, in order.
    dims = []
    for idx, expr in enumerate(exprs_list):
        shape = expr.check_shape()
        if len(shape) != 1:
            label = f"Dynamics expression {idx}" if several else "Dynamics expression"
            raise ValueError(
                f"{label} must be 1-dimensional (vector), but got shape {shape}"
            )
        dims.append(shape[0])

    actual_dim = sum(dims)
    if actual_dim == expected_dim:
        return

    if len(exprs_list) == 1:
        raise ValueError(
            f"Dynamics dimension mismatch: dynamics has dimension {actual_dim}, "
            f"but total state dimension is {expected_dim}. "
            f"States: {[(s.name, s.shape) for s in state_list]}"
        )
    raise ValueError(
        f"Dynamics dimension mismatch: {len(exprs_list)} dynamics expressions "
        f"have combined dimension {actual_dim} {dims}, "
        f"but total state dimension is {expected_dim}. "
        f"States: {[(s.name, s.shape) for s in state_list]}"
    )
|
|
526
|
+
|
|
527
|
+
|
|
528
|
+
def validate_dynamics_dict(
    dynamics: Dict[str, Expr],
    states: List[State],
    byof_dynamics: Optional[Dict[str, Callable]] = None,
) -> None:
    """Validate that dynamics dictionary keys match state names exactly.

    The keys of ``dynamics`` together with the keys of the optional
    ``byof_dynamics`` must be exactly the set of state names. This means no
    missing states, no extra keys, and no state defined in both dictionaries.

    Args:
        dynamics: Dictionary mapping state names to their symbolic dynamics
            expressions.
        states: List of State objects.
        byof_dynamics: Optional dictionary mapping state names to raw callables
            (e.g. JAX functions). States listed here must NOT also appear in
            ``dynamics``.

    Raises:
        ValueError: If a state appears in both ``dynamics`` and ``byof_dynamics``
            (checked first), or if the combined keys differ from the state names.

    Example:
        x = ox.State("x", shape=(3,))
        y = ox.State("y", shape=(2,))
        dynamics = {"x": x * 2, "y": y + 1}
        validate_dynamics_dict(dynamics, [x, y])  # OK

        bad_dynamics = {"x": x * 2}  # Missing "y"
        validate_dynamics_dict(bad_dynamics, [x, y])  # Raises ValueError

        # With byof_dynamics (expert user mode)
        dynamics = {"x": x * 2}  # Only symbolic for x
        byof_dynamics = {"y": some_jax_fn}  # Raw JAX for y
        validate_dynamics_dict(dynamics, [x, y], byof_dynamics)  # OK
    """
    state_names_set = {state.name for state in states}
    symbolic_keys = set(dynamics)
    byof_keys = set(byof_dynamics) if byof_dynamics else set()

    # A state can't be defined in both places.
    overlap = symbolic_keys & byof_keys
    if overlap:
        raise ValueError(
            f"States defined in both symbolic and byof dynamics: {overlap}\n"
            "Each state must have dynamics in exactly one place."
        )

    # Together, the two dicts must cover every state and nothing else.
    covered = symbolic_keys | byof_keys
    missing = state_names_set - covered
    extra = covered - state_names_set

    if missing or extra:
        error_msg = "Mismatch between state names and dynamics keys.\n"
        if missing:
            error_msg += f"  States missing from dynamics: {missing}\n"
        if extra:
            error_msg += f"  Extra keys in dynamics: {extra}\n"
        raise ValueError(error_msg)
|
|
587
|
+
|
|
588
|
+
|
|
589
|
+
def validate_dynamics_dict_dimensions(dynamics: Dict[str, Expr], states: List[State]) -> None:
    """Check, state by state, that each dynamics entry has the state's shape.

    For every state in ``states``, ``dynamics[state.name]`` is looked up.
    A raw Python ``int``/``float`` is treated as shape ``()``; anything else is
    asked for ``check_shape()``. Before comparing, a scalar shape ``()`` on either
    side is treated as ``(1,)``, mirroring how Concat handles scalars.

    Args:
        dynamics: Mapping from state name to its dynamics expression (or a number).
        states: State objects whose dynamics are being checked. Each state's name
            is looked up directly in ``dynamics``, so a missing name raises
            ``KeyError``.

    Raises:
        ValueError: If a dynamics entry's (normalized) shape differs from its
            state's (normalized) shape.

    Example:
        x = ox.State("x", shape=(3,))
        y = ox.State("y", shape=(2,))
        u = ox.Control("u", shape=(3,))
        validate_dynamics_dict_dimensions({"x": u, "y": y + 1}, [x, y])  # OK

        validate_dynamics_dict_dimensions({"x": u, "y": u}, [x, y])  # ValueError
    """

    def _as_vector_shape(shape: Tuple[int, ...]) -> Tuple[int, ...]:
        """Treat the scalar shape () as (1,); return any other shape unchanged."""
        return shape if len(shape) else (1,)

    for state in states:
        rhs = dynamics[state.name]
        want = state.shape

        # Bare Python numbers are converted to Constants later; treat them as scalars.
        got = () if isinstance(rhs, (int, float)) else rhs.check_shape()

        if _as_vector_shape(got) != _as_vector_shape(want):
            raise ValueError(
                f"Dynamics for state '{state.name}' has shape {got}, "
                f"but state has shape {want}"
            )
|
|
637
|
+
|
|
638
|
+
|
|
639
|
+
def validate_time_parameters(
    states: List[State],
    time: "Time",
) -> Tuple[
    bool,
    Union[float, tuple, None],
    Union[float, tuple, None],
    Union[float, None],
    Union[float, None],
    Union[float, None],
]:
    """Validate time parameter usage and configuration.

    There are two valid approaches for handling time in trajectory optimization:

    1. Auto-create time (recommended): Don't include "time" in states, provide a
       Time object. The time state is automatically created and managed.

    2. User-provided time (advanced): Include a "time" State in states. The Time
       object is still type-checked but otherwise ignored, and the user has full
       control over time dynamics.

    Args:
        states: List of State objects.
        time: Time configuration object (required and type-checked, but its
            values are ignored if a state named "time" exists).

    Returns:
        Tuple of (has_time_state, time_initial, time_final, time_derivative,
        time_min, time_max):
            - has_time_state: True if a state named "time" is in ``states``.
            - time_initial: ``time.initial`` (None if user-provided time).
            - time_final: ``time.final`` (None if user-provided time).
            - time_derivative: Always 1.0 for auto-created time (None if
              user-provided). The annotation is ``Union[float, None]`` to
              reflect that.
            - time_min: ``time.min`` (None if user-provided time).
            - time_max: ``time.max`` (None if user-provided time).

    Raises:
        ValueError: If ``time`` is not an instance of ``Time`` (including None).

    Example:
        # Approach 1: Auto-create time
        x = ox.State("x", shape=(3,))
        time_obj = ox.Time(initial=0.0, final=10.0)
        validate_time_parameters([x], time_obj)
        (False, 0.0, 10.0, 1.0, None, None)

        # Approach 2: User-provided time
        x = ox.State("x", shape=(3,))
        time_state = ox.State("time", shape=())
        validate_time_parameters([x, time_state], time_obj)
        (True, None, None, None, None, None)
    """
    # Local import, presumably to avoid an import cycle at module load time.
    from openscvx.symbolic.time import Time

    if not isinstance(time, Time):
        raise ValueError(f"Expected Time object, but got {type(time).__name__}")

    has_time_state = any(state.name == "time" for state in states)

    if has_time_state:
        # Approach 2: the user's own "time" State defines everything, so no
        # values are taken from the Time object.
        return True, None, None, None, None, None

    # Approach 1: auto-created time state. It advances at unit rate, so the
    # time derivative is always 1.0. Attributes are read in the order
    # initial, final, min, max.
    time_initial = time.initial
    time_final = time.final
    time_derivative = 1.0
    time_min = time.min
    time_max = time.max

    return False, time_initial, time_final, time_derivative, time_min, time_max
|
|
711
|
+
|
|
712
|
+
|
|
713
|
+
def convert_dynamics_dict_to_expr(
    dynamics: Dict[str, Expr], states: List[State]
) -> Tuple[Dict[str, Expr], Expr]:
    """Build the full ODE right-hand side from per-state dynamics.

    Takes a mapping of state name to dynamics expression and produces a single
    concatenated expression for x_dot = f(x, u, t). The pieces are concatenated
    in the order given by ``states``, not the dict's order, so variable ordering
    is deterministic.

    Raw Python numbers (``int``/``float``) found in the mapping are wrapped in
    ``Constant`` so every entry is an expression.

    Args:
        dynamics: Mapping from state name to its dynamics expression (or a raw
            int/float). It is not modified.
        states: State objects whose order defines the concatenation order.
            Every ``state.name`` must be a key of ``dynamics``; a missing key
            raises ``KeyError``.

    Returns:
        A pair ``(converted, concatenated)``:
            - converted: A new dict with the same keys as ``dynamics``, where
              scalars are replaced by ``Constant`` expressions.
            - concatenated: ``Concat`` of the converted entries, ordered by
              ``states``.

    Example:
        x = ox.State("x", shape=(3,))
        y = ox.State("y", shape=(1,))
        dynamics_dict = {"x": x * 2, "y": 1.0}  # scalar dynamics for y
        converted_dict, concat_expr = convert_dynamics_dict_to_expr(
            dynamics_dict, [x, y]
        )
        # converted_dict["y"] is now Constant(1.0)
        # concat_expr is Concat(x * 2, Constant(1.0))
    """
    # Work on a shallow copy so the caller's mapping is never mutated.
    source = dict(dynamics)
    converted = {
        name: (Constant(expr) if isinstance(expr, (int, float)) else expr)
        for name, expr in source.items()
    }

    ordered_exprs = [converted[state.name] for state in states]
    return converted, Concat(*ordered_exprs)
|