openscvx 0.3.2.dev170__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of openscvx might be problematic. Click here for more details.
- openscvx/__init__.py +123 -0
- openscvx/_version.py +34 -0
- openscvx/algorithms/__init__.py +92 -0
- openscvx/algorithms/autotuning.py +24 -0
- openscvx/algorithms/base.py +351 -0
- openscvx/algorithms/optimization_results.py +215 -0
- openscvx/algorithms/penalized_trust_region.py +384 -0
- openscvx/config.py +437 -0
- openscvx/discretization/__init__.py +47 -0
- openscvx/discretization/discretization.py +236 -0
- openscvx/expert/__init__.py +23 -0
- openscvx/expert/byof.py +326 -0
- openscvx/expert/lowering.py +419 -0
- openscvx/expert/validation.py +357 -0
- openscvx/integrators/__init__.py +48 -0
- openscvx/integrators/runge_kutta.py +281 -0
- openscvx/lowered/__init__.py +30 -0
- openscvx/lowered/cvxpy_constraints.py +23 -0
- openscvx/lowered/cvxpy_variables.py +124 -0
- openscvx/lowered/dynamics.py +34 -0
- openscvx/lowered/jax_constraints.py +133 -0
- openscvx/lowered/parameters.py +54 -0
- openscvx/lowered/problem.py +70 -0
- openscvx/lowered/unified.py +718 -0
- openscvx/plotting/__init__.py +63 -0
- openscvx/plotting/plotting.py +756 -0
- openscvx/plotting/scp_iteration.py +299 -0
- openscvx/plotting/viser/__init__.py +126 -0
- openscvx/plotting/viser/animated.py +605 -0
- openscvx/plotting/viser/plotly_integration.py +333 -0
- openscvx/plotting/viser/primitives.py +355 -0
- openscvx/plotting/viser/scp.py +459 -0
- openscvx/plotting/viser/server.py +112 -0
- openscvx/problem.py +734 -0
- openscvx/propagation/__init__.py +60 -0
- openscvx/propagation/post_processing.py +104 -0
- openscvx/propagation/propagation.py +248 -0
- openscvx/solvers/__init__.py +51 -0
- openscvx/solvers/cvxpy.py +226 -0
- openscvx/symbolic/__init__.py +9 -0
- openscvx/symbolic/augmentation.py +630 -0
- openscvx/symbolic/builder.py +492 -0
- openscvx/symbolic/constraint_set.py +92 -0
- openscvx/symbolic/expr/__init__.py +222 -0
- openscvx/symbolic/expr/arithmetic.py +517 -0
- openscvx/symbolic/expr/array.py +632 -0
- openscvx/symbolic/expr/constraint.py +796 -0
- openscvx/symbolic/expr/control.py +135 -0
- openscvx/symbolic/expr/expr.py +720 -0
- openscvx/symbolic/expr/lie/__init__.py +87 -0
- openscvx/symbolic/expr/lie/adjoint.py +357 -0
- openscvx/symbolic/expr/lie/se3.py +172 -0
- openscvx/symbolic/expr/lie/so3.py +138 -0
- openscvx/symbolic/expr/linalg.py +279 -0
- openscvx/symbolic/expr/math.py +699 -0
- openscvx/symbolic/expr/spatial.py +209 -0
- openscvx/symbolic/expr/state.py +607 -0
- openscvx/symbolic/expr/stl.py +136 -0
- openscvx/symbolic/expr/variable.py +321 -0
- openscvx/symbolic/hashing.py +112 -0
- openscvx/symbolic/lower.py +760 -0
- openscvx/symbolic/lowerers/__init__.py +106 -0
- openscvx/symbolic/lowerers/cvxpy.py +1302 -0
- openscvx/symbolic/lowerers/jax.py +1382 -0
- openscvx/symbolic/preprocessing.py +757 -0
- openscvx/symbolic/problem.py +110 -0
- openscvx/symbolic/time.py +116 -0
- openscvx/symbolic/unified.py +420 -0
- openscvx/utils/__init__.py +20 -0
- openscvx/utils/cache.py +131 -0
- openscvx/utils/caching.py +210 -0
- openscvx/utils/printing.py +301 -0
- openscvx/utils/profiling.py +37 -0
- openscvx/utils/utils.py +100 -0
- openscvx-0.3.2.dev170.dist-info/METADATA +350 -0
- openscvx-0.3.2.dev170.dist-info/RECORD +79 -0
- openscvx-0.3.2.dev170.dist-info/WHEEL +5 -0
- openscvx-0.3.2.dev170.dist-info/licenses/LICENSE +201 -0
- openscvx-0.3.2.dev170.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,630 @@
|
|
|
1
|
+
"""State and dynamics augmentation for continuous-time constraint satisfaction.
|
|
2
|
+
|
|
3
|
+
This module provides utilities for augmenting trajectory optimization problems with
|
|
4
|
+
additional states and dynamics to handle continuous-time constraint satisfaction (CTCS).
|
|
5
|
+
The CTCS method enforces path constraints continuously along the trajectory rather than
|
|
6
|
+
just at discretization nodes.
|
|
7
|
+
|
|
8
|
+
Key functionality:
|
|
9
|
+
- CTCS constraint grouping: Sort and group CTCS constraints by time intervals
|
|
10
|
+
- Constraint separation: Separate CTCS, nodal, and convex constraints
|
|
11
|
+
- Vector decomposition: Decompose vector constraints into scalar components
|
|
12
|
+
- Time augmentation: Add time state with appropriate dynamics and constraints
|
|
13
|
+
- CTCS dynamics augmentation: Add augmented states and time dilation control
|
|
14
|
+
|
|
15
|
+
The augmentation process transforms the original dynamics x_dot = f(x, u) into an
|
|
16
|
+
augmented system with additional states for constraint satisfaction and time dilation.
|
|
17
|
+
|
|
18
|
+
Architecture:
|
|
19
|
+
The CTCS method works by:
|
|
20
|
+
|
|
21
|
+
1. Grouping constraints by time interval and assigning index (idx)
|
|
22
|
+
2. Creating augmented states (one per constraint group)
|
|
23
|
+
3. Adding penalty dynamics: aug_dot = penalty(constraint_violation)
|
|
24
|
+
4. Adding time dilation control to slow down near constraint boundaries
|
|
25
|
+
|
|
26
|
+
Example:
|
|
27
|
+
Augmenting dynamics with CTCS constraints::
|
|
28
|
+
|
|
29
|
+
import openscvx as ox
|
|
30
|
+
|
|
31
|
+
# Define problem
|
|
32
|
+
x = ox.State("x", shape=(3,))
|
|
33
|
+
u = ox.Control("u", shape=(2,))
|
|
34
|
+
|
|
35
|
+
# Create dynamics
|
|
36
|
+
xdot = u @ A # Some dynamics expression
|
|
37
|
+
|
|
38
|
+
# Define path constraint
|
|
39
|
+
path_constraint = (ox.Norm(x) <= 1.0).over((0, 50)) # CTCS constraint
|
|
40
|
+
|
|
41
|
+
# Augment dynamics with CTCS
|
|
42
|
+
from openscvx.symbolic.augmentation import augment_dynamics_with_ctcs
|
|
43
|
+
|
|
44
|
+
xdot_aug, states_aug, controls_aug = augment_dynamics_with_ctcs(
|
|
45
|
+
xdot=xdot,
|
|
46
|
+
states=[x],
|
|
47
|
+
controls=[u],
|
|
48
|
+
constraints_ctcs=[path_constraint],
|
|
49
|
+
N=50
|
|
50
|
+
)
|
|
51
|
+
# xdot_aug now includes augmented state dynamics
|
|
52
|
+
# states_aug includes original states + augmented states
|
|
53
|
+
# controls_aug includes original controls + time dilation
|
|
54
|
+
"""
|
|
55
|
+
|
|
56
|
+
from typing import Dict, List, Optional, Tuple
|
|
57
|
+
|
|
58
|
+
import numpy as np
|
|
59
|
+
|
|
60
|
+
from openscvx.symbolic.constraint_set import ConstraintSet
|
|
61
|
+
from openscvx.symbolic.expr import (
|
|
62
|
+
CTCS,
|
|
63
|
+
Add,
|
|
64
|
+
Concat,
|
|
65
|
+
Constraint,
|
|
66
|
+
CrossNodeConstraint,
|
|
67
|
+
Expr,
|
|
68
|
+
Index,
|
|
69
|
+
NodalConstraint,
|
|
70
|
+
)
|
|
71
|
+
from openscvx.symbolic.expr.control import Control
|
|
72
|
+
from openscvx.symbolic.expr.state import State
|
|
73
|
+
|
|
74
|
+
|
|
75
|
+
def sort_ctcs_constraints(
    constraints_ctcs: List[CTCS],
) -> Tuple[List[CTCS], List[Tuple[int, int]], int]:
    """Assign a group index to every CTCS constraint and collect the group intervals.

    Each CTCS constraint is mapped to an augmented-state group identified by its
    ``idx``. Constraints are processed strictly in list order:

    - A constraint that already carries an ``idx`` claims that group for its
      ``nodes`` interval. Reusing an ``idx`` with a different interval is an
      error. Distinct user-supplied ``idx`` values may share one interval; they
      then remain separate groups.
    - A constraint without an ``idx`` joins the first group already registered
      for an identical interval. If there is none, it receives the smallest
      unused index at or above an internal counter.

    Because resolution is order-dependent, an index auto-assigned early can
    collide with a user-supplied ``idx`` that appears later in the list with a
    different interval; that case raises ``ValueError``.

    Once every constraint is processed, the group indices must be exactly
    ``0 .. n-1``.

    Note:
        ``idx`` is written back onto the constraint objects in place.

    Args:
        constraints_ctcs: CTCS constraints to group. Each must expose ``nodes``
            and a mutable ``idx`` attribute.

    Returns:
        Tuple of:
        - The same list object that was passed in, with ``idx`` populated
        - The node interval of each group, ordered by ascending ``idx``
        - The number of groups (i.e. augmented states needed)

    Raises:
        ValueError: If one ``idx`` is used with two different intervals, or if
            the final set of indices is not contiguous from 0.

    Example:
        Group three constraints::

            c1 = (x <= 5).over((0, 50))    # auto-assigned idx 0
            c2 = (y <= 10).over((0, 50))   # same interval -> shares idx 0
            c3 = (z <= 15).over((20, 80))  # new interval -> idx 1
            sorted_ctcs, intervals, n_aug = sort_ctcs_constraints([c1, c2, c3])
            # intervals == [(0, 50), (20, 80)], n_aug == 2
    """
    groups: Dict[int, Tuple[int, int]] = {}
    counter = 0

    for ctcs in constraints_ctcs:
        interval = ctcs.nodes

        if ctcs.idx is None:
            # Join the first already-registered group with an identical interval.
            shared = next(
                (gid for gid, span in groups.items() if span == interval),
                None,
            )
            if shared is not None:
                ctcs.idx = shared
                continue

            # Unseen interval: claim the lowest id not yet taken.
            while counter in groups:
                counter += 1
            ctcs.idx = counter
            groups[counter] = interval
            counter += 1
            continue

        # Explicit idx: register it, or verify it still maps to the same interval.
        if ctcs.idx not in groups:
            groups[ctcs.idx] = interval
        elif groups[ctcs.idx] != interval:
            raise ValueError(
                f"idx={ctcs.idx} was first used with interval={groups[ctcs.idx]}, "
                f"but now you gave it interval={interval}"
            )

    ordered_ids = sorted(groups)
    expected_ids = list(range(len(ordered_ids)))
    if ordered_ids != expected_ids:
        raise ValueError(
            f"CTCS constraint idx values must form a contiguous block starting from 0. "
            f"Got {ordered_ids}, expected {expected_ids}"
        )

    intervals = [groups[gid] for gid in ordered_ids]
    return constraints_ctcs, intervals, len(ordered_ids)
|
|
162
|
+
|
|
163
|
+
|
|
164
|
+
def separate_constraints(constraint_set: ConstraintSet, n_nodes: int) -> ConstraintSet:
    """Route every constraint in ``constraint_set.unsorted`` to its category list.

    Each entry of ``unsorted`` is classified by type, in this order:

    - ``CTCS``: appended to ``ctcs``. A falsy ``nodes`` interval is replaced in
      place with ``(0, n_nodes)``.
    - ``NodalConstraint``: appended to ``nodal_convex`` or ``nodal`` according to
      the wrapped constraint's ``is_convex``.
    - Bare ``Constraint`` containing NodeReferences (from ``.at(k)``): wrapped in
      ``CrossNodeConstraint`` and appended to ``cross_node_convex`` or
      ``cross_node``.
    - Any other bare ``Constraint``: wrapped in a ``NodalConstraint`` over nodes
      ``0 .. n_nodes-1`` and appended to ``nodal_convex`` or ``nodal``.

    After that, ``unsorted`` is emptied. Every CTCS constraint in ``ctcs`` that
    has ``check_nodally`` set is also added as a ``NodalConstraint`` over nodes
    ``start .. end-1`` of its interval; this includes CTCS constraints that were
    already in ``ctcs`` before this call. Finally, every cross-node constraint is
    passed to ``validate_cross_node_constraint``.

    Note:
        The ConstraintSet is mutated in place. If a ValueError is raised part-way
        through ``unsorted``, the constraints processed before the failure have
        already been appended to their category lists, while ``unsorted`` itself
        is still intact.

    Args:
        constraint_set: ConstraintSet whose raw constraints live in ``unsorted``.
        n_nodes: Total number of nodes in the trajectory.

    Returns:
        The same ConstraintSet object, with ``unsorted`` drained and the category
        lists populated.

    Raises:
        ValueError: If an entry is not a ``Constraint``, ``NodalConstraint`` or
            ``CTCS``.
        ValueError: If a ``NodalConstraint`` wraps an expression containing
            NodeReferences (use a bare Constraint instead).
        ValueError: If a ``CTCS`` wraps an expression containing NodeReferences.
        Whatever ``validate_cross_node_constraint`` raises for an invalid
        cross-node constraint.

    Example:
        Separate and categorize constraints::

            x = ox.State("x", shape=(3,))
            constraint_set = ConstraintSet(unsorted=[
                (x <= 5).over((0, 50)),      # CTCS
                (x >= 0).at([0, 10, 20]),    # NodalConstraint
                ox.Norm(x) <= 1,             # Bare -> all nodes
                x.at(5) - x.at(4) <= 0.1,    # Bare with NodeRef -> cross-node
            ])
            separate_constraints(constraint_set, n_nodes=50)
            assert constraint_set.is_categorized
    """
    from openscvx.symbolic.lower import _contains_node_reference

    def _nodal_target(is_convex: bool) -> list:
        # Convexity alone decides which nodal list receives the constraint.
        return constraint_set.nodal_convex if is_convex else constraint_set.nodal

    for item in constraint_set.unsorted:
        if isinstance(item, CTCS):
            if _contains_node_reference(item.constraint):
                raise ValueError(
                    "CTCS constraints cannot contain NodeReferences (.at(k)). "
                    "Cross-node constraints should be specified as bare Constraint objects. "
                    f"Constraint: {item.constraint}"
                )
            # A CTCS without an interval covers the whole horizon.
            item.nodes = item.nodes or (0, n_nodes)
            constraint_set.ctcs.append(item)
            continue

        if isinstance(item, NodalConstraint):
            # .at([...]) already pins the nodes; inner .at(k) references conflict.
            if _contains_node_reference(item.constraint):
                raise ValueError(
                    "Cross-node constraints should not use .at([...]) wrapper. "
                    "The constraint already references specific nodes via .at(k) inside the "
                    "expression. Remove the outer .at([...]) wrapper and use the bare "
                    "constraint directly. "
                    f"Constraint: {item.constraint}"
                )
            _nodal_target(item.constraint.is_convex).append(item)
            continue

        if not isinstance(item, Constraint):
            raise ValueError(
                "Constraints must be `Constraint`, `NodalConstraint`, or `CTCS`, "
                f"got {type(item).__name__}"
            )

        if _contains_node_reference(item):
            # Bare constraint that names explicit nodes -> cross-node constraint.
            wrapped = CrossNodeConstraint(item)
            if item.is_convex:
                constraint_set.cross_node_convex.append(wrapped)
            else:
                constraint_set.cross_node.append(wrapped)
        else:
            # Bare constraint without node references -> enforced at every node.
            wrapped = NodalConstraint(item, list(range(n_nodes)))
            _nodal_target(item.is_convex).append(wrapped)

    constraint_set.unsorted = []

    # check_nodally CTCS constraints are also enforced at each node of their
    # [start, end) interval; NodeReferences were already rejected above.
    for inner, span in get_nodal_constraints_from_ctcs(constraint_set.ctcs):
        extra = NodalConstraint(inner, list(range(span[0], span[1])))
        _nodal_target(inner.is_convex).append(extra)

    from openscvx.symbolic.preprocessing import validate_cross_node_constraint

    for cross in (*constraint_set.cross_node, *constraint_set.cross_node_convex):
        validate_cross_node_constraint(cross, n_nodes)

    return constraint_set
|
|
284
|
+
|
|
285
|
+
|
|
286
|
+
def decompose_vector_nodal_constraints(
    constraints_nodal: List[NodalConstraint],
) -> List[NodalConstraint]:
    """Split vector-valued nodal constraints into one scalar constraint per element.

    Nonconvex nodal constraints are lowered to JAX functions, and the JAX-to-CVXPY
    interface expects a scalar constraint value at each node. Every
    vector-shaped residual is therefore broken up element by element. CTCS
    constraints do not need this, since they handle vector values internally.

    Every non-empty residual shape is decomposed, including ``(1,)``; this avoids
    vmap adding an extra dimension when stacking results. One constraint of the
    same class is created per element, enumerated in row-major (C) order, and the
    right-hand side is reused unchanged:

    - 1-D residuals are indexed with a plain integer ``i``.
    - Higher-rank residuals are indexed with an integer tuple ``(i, j, ...)``.

    Scalar residuals (shape ``()``) are passed through unchanged. If shape
    analysis or construction of the element constraints raises, the original
    constraint is kept as-is. No partial set of element constraints is emitted
    for it.

    Args:
        constraints_nodal (List[NodalConstraint]): NodalConstraint objects whose
            constraints are canonicalized.

    Returns:
        List of NodalConstraint objects, with vector constraints replaced by
        their scalar components (each keeping the original node list).

    Note:
        Constraints are assumed to be in canonical form ``residual <= 0`` or
        ``residual == 0``, where ``residual`` is the constraint's ``lhs``.

    Example:
        Decompose a vector constraint into 3 constraints::

            x = ox.State("x", shape=(3,))
            constraint = (x <= 5).at([0, 10, 20])  # Vector constraint, shape (3,)
            decomposed = decompose_vector_nodal_constraints([constraint])
            # Returns 3 constraints: x[0] <= 5, x[1] <= 5, x[2] <= 5
    """
    decomposed_constraints: List[NodalConstraint] = []

    for nodal_constraint in constraints_nodal:
        constraint = nodal_constraint.constraint
        nodes = nodal_constraint.nodes

        try:
            # Canonical form: residual <= 0 or residual == 0, residual = lhs.
            residual_shape = constraint.lhs.check_shape()

            if len(residual_shape) == 0:
                pieces = [nodal_constraint]
            else:
                # Build this constraint's pieces locally so that a failure
                # part-way through cannot leave partial output behind.
                pieces = []
                for multi_index in np.ndindex(*residual_shape):
                    # Plain int for 1-D residuals (unchanged behaviour). For rank >= 2,
                    # a full integer tuple selects a single element.
                    # NOTE(review): assumes Index accepts NumPy-style integer-tuple
                    # indices; if it raises, the except below keeps the original.
                    index = multi_index[0] if len(multi_index) == 1 else multi_index
                    indexed_lhs = Index(constraint.lhs, index)
                    indexed_rhs = constraint.rhs  # Should be Constant(0)
                    indexed_constraint = constraint.__class__(indexed_lhs, indexed_rhs)
                    pieces.append(NodalConstraint(indexed_constraint, nodes))
        except Exception:
            # Best-effort: if shape analysis or decomposition fails, keep the
            # original constraint for backward compatibility.
            pieces = [nodal_constraint]

        decomposed_constraints.extend(pieces)

    return decomposed_constraints
|
|
352
|
+
|
|
353
|
+
|
|
354
|
+
def get_nodal_constraints_from_ctcs(
    constraints_ctcs: List[CTCS],
) -> List[tuple[Constraint, tuple[int, int]]]:
    """Collect the inner constraints of CTCS wrappers flagged ``check_nodally``.

    A CTCS wrapper with a truthy ``check_nodally`` asks for its underlying
    constraint to be enforced at the discretization nodes in addition to the
    continuous-time enforcement. This helper only gathers those constraints
    together with the wrapper's ``nodes`` interval; it does not expand the
    interval into node indices.

    Args:
        constraints_ctcs: CTCS constraint wrappers, in any order.

    Returns:
        ``(constraint, nodes)`` pairs, in input order, one per flagged wrapper:
        - constraint: the wrapper's ``constraint`` attribute
        - nodes: the wrapper's ``nodes`` interval, e.g. ``(start, end)``

    Example:
        Extract a CTCS constraint that should also be checked at nodes::

            x = ox.State("x", shape=(3,))
            constraint = (x <= 5).over((10, 50), check_nodally=True)
            nodal = get_nodal_constraints_from_ctcs([constraint])
            # [(x <= 5, (10, 50))]; callers such as separate_constraints expand
            # (start, end) into nodes start .. end-1, i.e. 10 .. 49 here.
    """
    return [
        (wrapper.constraint, wrapper.nodes)
        for wrapper in constraints_ctcs
        if wrapper.check_nodally
    ]
|
|
386
|
+
|
|
387
|
+
|
|
388
|
+
def augment_with_time_state(
    states: List[State],
    constraints: ConstraintSet,
    time_initial: float | tuple,
    time_final: float | tuple,
    time_min: float,
    time_max: float,
    N: int,
    time_scaling_min: Optional[float] = None,
    time_scaling_max: Optional[float] = None,
) -> Tuple[List[State], ConstraintSet]:
    """Augment a problem with a ``"time"`` state variable, if it lacks one.

    When no state named ``"time"`` is present, a scalar ``State("time")`` is
    created and configured as follows:

    - bounds ``[time_min, time_max]``
    - boundary conditions ``time_initial`` / ``time_final``
    - a linear initial guess over ``N`` nodes
    - optional scaling bounds

    The new state is appended to a copy of ``states``, and two CTCS constraints
    (``time <= max`` and ``min <= time``) are appended to
    ``constraints.unsorted`` so the bounds are enforced continuously.

    Args:
        states: List of State objects. The list itself is not modified; a new
            list is returned.
        constraints: ConstraintSet with unsorted constraints. It is modified in
            place when a time state is created.
        time_initial: Initial time boundary condition:
            - float: Fixed initial time
            - tuple: ("free", guess) for free initial time with initial guess
        time_final: Final time boundary condition (same format as time_initial).
        time_min: Minimum bound for the time variable throughout the trajectory.
        time_max: Maximum bound for the time variable throughout the trajectory.
        N: Number of discretization nodes (length of the initial guess).
        time_scaling_min: If not None, stored as a 1-element array in the new
            time state's ``scaling_min``.
        time_scaling_max: If not None, stored as a 1-element array in the new
            time state's ``scaling_max``.

    Returns:
        Tuple of:
        - New states list (original states, plus the time state if created)
        - The same ConstraintSet, with the time CTCS constraints added to
          ``unsorted`` if a time state was created

    Note:
        If a state named "time" already exists, it is not modified and no
        constraints are added.

    Example:
        Get augmented states::

            x = ox.State("x", shape=(3,))
            constraints = ConstraintSet()
            states_aug, constraints = augment_with_time_state(
                states=[x],
                constraints=constraints,
                time_initial=0.0,
                time_final=("free", 10.0),
                time_min=0.0,
                time_max=100.0,
                N=50,
            )
            # states_aug now includes a time state with initial=0, final=free
    """
    # Copy so the caller's list is not mutated.
    states_aug = list(states)

    if any(state.name == "time" for state in states_aug):
        # A user-provided time state is left untouched; no constraints added.
        return states_aug, constraints

    time_state = State("time", shape=(1,))
    time_state.min = np.array([time_min])
    time_state.max = np.array([time_max])

    # Set time boundary conditions
    time_state.initial = [time_initial]
    time_state.final = [time_final]

    def _guess_value(boundary):
        # Tuple/list conditions such as ("free", guess) carry the numeric guess
        # second. Anything else, including NumPy integer or 0-d values read back
        # from the State, is already the value and is used directly; the old
        # isinstance(int, float) test would have tried boundary[1] on those.
        if isinstance(boundary, (tuple, list)):
            return boundary[1]
        return boundary

    # Linear initial guess between the initial and final time guesses
    time_guess_start = _guess_value(time_state.initial[0])
    time_guess_end = _guess_value(time_state.final[0])
    time_state.guess = np.linspace(time_guess_start, time_guess_end, N).reshape(-1, 1)

    # Transfer scaling_min/max from the Time object if provided
    if time_scaling_min is not None:
        time_state.scaling_min = np.array([time_scaling_min])
    if time_scaling_max is not None:
        time_state.scaling_max = np.array([time_scaling_max])

    states_aug.append(time_state)

    # Enforce the time bounds continuously via CTCS (categorized later)
    constraints.unsorted.append(CTCS(time_state <= time_state.max))
    constraints.unsorted.append(CTCS(time_state.min <= time_state))

    return states_aug, constraints
|
|
493
|
+
|
|
494
|
+
|
|
495
|
+
def augment_dynamics_with_ctcs(
    xdot: Expr,
    states: List[State],
    controls: List[Control],
    constraints_ctcs: List[CTCS],
    N: int,
    licq_min: float = 0.0,
    licq_max: float = 1e-4,
    time_dilation_factor_min: float = 0.3,
    time_dilation_factor_max: float = 3.0,
) -> Tuple[Expr, List[State], List[Control]]:
    """Append CTCS augmented states and a time-dilation control to a problem.

    The constraints in ``constraints_ctcs`` are bucketed by their ``idx``, so
    they should already be grouped (for example by ``sort_ctcs_constraints``).
    For each group, in ascending ``idx`` order, one dynamics entry is
    concatenated after ``xdot``. That entry is the group's single CTCS wrapper,
    or the ``Add`` of all of them. The wrappers are kept intact so their node
    intervals are preserved; the penalty logic is presumably applied at
    lowering time (the JAX lowerer's ``visit_ctcs()``).

    For every group a scalar state ``_ctcs_aug_<k>`` is created, with
    ``k = 0 .. num_groups-1``. Its initial value is ``licq_min``, its final value
    is free (guess 0), its bounds are ``[licq_min, licq_max]`` and its guess is
    the constant ``licq_min`` over ``N`` nodes. State names follow
    ``range(num_groups)`` while dynamics entries follow sorted ``idx``. The two
    line up only when the idx values are contiguous from 0, which
    ``sort_ctcs_constraints`` guarantees.

    A scalar control ``_time_dilation`` is always appended, even when
    ``constraints_ctcs`` is empty. With ``tf = time_state.final[0]`` of the
    state named ``"time"``, its bounds are
    ``[time_dilation_factor_min * tf, time_dilation_factor_max * tf]`` and its
    guess is the constant ``tf``.

    Args:
        xdot: Original dynamics expression for the states.
        states: State variables; must include a state named "time".
        controls: Control variables.
        constraints_ctcs: CTCS constraints with ``idx`` already assigned.
        N: Number of discretization nodes.
        licq_min: Lower bound, initial value and guess of the augmented states
            (default: 0.0).
        licq_max: Upper bound of the augmented states (default: 1e-4).
        time_dilation_factor_min: Multiplier on ``tf`` for the time-dilation
            lower bound (default: 0.3).
        time_dilation_factor_max: Multiplier on ``tf`` for the time-dilation
            upper bound (default: 3.0).

    Returns:
        Tuple of:
        - Augmented dynamics (``xdot`` itself when there are no CTCS constraints)
        - New states list (original + augmented states)
        - New controls list (original + time dilation control)

    Raises:
        ValueError: If no state named "time" is found in ``states``.

    Example:
        Augment dynamics with CTCS penalty states::

            x = ox.State("x", shape=(3,))
            u = ox.Control("u", shape=(2,))
            time = ox.State("time", shape=(1,))
            xdot = u @ A  # Some dynamics
            constraint = (ox.Norm(x) <= 1.0).over((0, 50))
            xdot_aug, states_aug, controls_aug = augment_dynamics_with_ctcs(
                xdot=xdot,
                states=[x, time],
                controls=[u],
                constraints_ctcs=[constraint],
                N=50,
            )
            # states_aug: x, time, _ctcs_aug_0; controls_aug: u, _time_dilation
    """
    # Work on copies so the caller's lists are left untouched.
    states_augmented = list(states)
    controls_augmented = list(controls)

    xdot_aug = xdot
    if constraints_ctcs:
        # Bucket whole CTCS wrappers by group index (node intervals preserved).
        # TODO: apply per-constraint scaling here once CTCS carries a scaling attribute.
        groups: Dict[int, List[Expr]] = {}
        for ctcs in constraints_ctcs:
            groups.setdefault(ctcs.idx, []).append(ctcs)

        # One dynamics entry per group: a lone term as-is, several terms summed.
        group_exprs = []
        for gid in sorted(groups):
            terms = groups[gid]
            group_exprs.append(terms[0] if len(terms) == 1 else Add(*terms))

        for k in range(len(groups)):
            aug_state = State(f"_ctcs_aug_{k}", shape=(1,))
            aug_state.initial = np.array([licq_min])  # start inside the bounds
            aug_state.final = [("free", 0)]
            aug_state.min = np.array([licq_min])
            aug_state.max = np.array([licq_max])
            aug_state.guess = np.full([N, 1], licq_min)  # N x 1 constant guess
            states_augmented.append(aug_state)

        xdot_aug = Concat(xdot, *group_exprs)

    time_dilation = Control("_time_dilation", shape=(1,))

    time_state = next((s for s in states if s.name == "time"), None)
    if time_state is None:
        raise ValueError("No state named 'time' found in states list")

    # NOTE(review): the bounds scale with the final-time value alone; if the
    # initial time can be nonzero, the trajectory duration is tf - t0. Confirm
    # that this is intended.
    final_time = time_state.final[0]
    time_dilation.min = np.array([time_dilation_factor_min * final_time])
    time_dilation.max = np.array([time_dilation_factor_max * final_time])
    time_dilation.guess = np.ones([N, 1]) * final_time

    controls_augmented.append(time_dilation)

    return xdot_aug, states_augmented, controls_augmented
|