openscvx 0.3.2.dev170__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of openscvx might be problematic. Click here for more details.
- openscvx/__init__.py +123 -0
- openscvx/_version.py +34 -0
- openscvx/algorithms/__init__.py +92 -0
- openscvx/algorithms/autotuning.py +24 -0
- openscvx/algorithms/base.py +351 -0
- openscvx/algorithms/optimization_results.py +215 -0
- openscvx/algorithms/penalized_trust_region.py +384 -0
- openscvx/config.py +437 -0
- openscvx/discretization/__init__.py +47 -0
- openscvx/discretization/discretization.py +236 -0
- openscvx/expert/__init__.py +23 -0
- openscvx/expert/byof.py +326 -0
- openscvx/expert/lowering.py +419 -0
- openscvx/expert/validation.py +357 -0
- openscvx/integrators/__init__.py +48 -0
- openscvx/integrators/runge_kutta.py +281 -0
- openscvx/lowered/__init__.py +30 -0
- openscvx/lowered/cvxpy_constraints.py +23 -0
- openscvx/lowered/cvxpy_variables.py +124 -0
- openscvx/lowered/dynamics.py +34 -0
- openscvx/lowered/jax_constraints.py +133 -0
- openscvx/lowered/parameters.py +54 -0
- openscvx/lowered/problem.py +70 -0
- openscvx/lowered/unified.py +718 -0
- openscvx/plotting/__init__.py +63 -0
- openscvx/plotting/plotting.py +756 -0
- openscvx/plotting/scp_iteration.py +299 -0
- openscvx/plotting/viser/__init__.py +126 -0
- openscvx/plotting/viser/animated.py +605 -0
- openscvx/plotting/viser/plotly_integration.py +333 -0
- openscvx/plotting/viser/primitives.py +355 -0
- openscvx/plotting/viser/scp.py +459 -0
- openscvx/plotting/viser/server.py +112 -0
- openscvx/problem.py +734 -0
- openscvx/propagation/__init__.py +60 -0
- openscvx/propagation/post_processing.py +104 -0
- openscvx/propagation/propagation.py +248 -0
- openscvx/solvers/__init__.py +51 -0
- openscvx/solvers/cvxpy.py +226 -0
- openscvx/symbolic/__init__.py +9 -0
- openscvx/symbolic/augmentation.py +630 -0
- openscvx/symbolic/builder.py +492 -0
- openscvx/symbolic/constraint_set.py +92 -0
- openscvx/symbolic/expr/__init__.py +222 -0
- openscvx/symbolic/expr/arithmetic.py +517 -0
- openscvx/symbolic/expr/array.py +632 -0
- openscvx/symbolic/expr/constraint.py +796 -0
- openscvx/symbolic/expr/control.py +135 -0
- openscvx/symbolic/expr/expr.py +720 -0
- openscvx/symbolic/expr/lie/__init__.py +87 -0
- openscvx/symbolic/expr/lie/adjoint.py +357 -0
- openscvx/symbolic/expr/lie/se3.py +172 -0
- openscvx/symbolic/expr/lie/so3.py +138 -0
- openscvx/symbolic/expr/linalg.py +279 -0
- openscvx/symbolic/expr/math.py +699 -0
- openscvx/symbolic/expr/spatial.py +209 -0
- openscvx/symbolic/expr/state.py +607 -0
- openscvx/symbolic/expr/stl.py +136 -0
- openscvx/symbolic/expr/variable.py +321 -0
- openscvx/symbolic/hashing.py +112 -0
- openscvx/symbolic/lower.py +760 -0
- openscvx/symbolic/lowerers/__init__.py +106 -0
- openscvx/symbolic/lowerers/cvxpy.py +1302 -0
- openscvx/symbolic/lowerers/jax.py +1382 -0
- openscvx/symbolic/preprocessing.py +757 -0
- openscvx/symbolic/problem.py +110 -0
- openscvx/symbolic/time.py +116 -0
- openscvx/symbolic/unified.py +420 -0
- openscvx/utils/__init__.py +20 -0
- openscvx/utils/cache.py +131 -0
- openscvx/utils/caching.py +210 -0
- openscvx/utils/printing.py +301 -0
- openscvx/utils/profiling.py +37 -0
- openscvx/utils/utils.py +100 -0
- openscvx-0.3.2.dev170.dist-info/METADATA +350 -0
- openscvx-0.3.2.dev170.dist-info/RECORD +79 -0
- openscvx-0.3.2.dev170.dist-info/WHEEL +5 -0
- openscvx-0.3.2.dev170.dist-info/licenses/LICENSE +201 -0
- openscvx-0.3.2.dev170.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,419 @@
|
|
|
1
|
+
"""Lowering logic for bring-your-own-functions (byof).
|
|
2
|
+
|
|
3
|
+
This module handles integration of user-provided JAX functions into the
|
|
4
|
+
lowered problem representation, including dynamics splicing and constraint
|
|
5
|
+
addition.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
from typing import TYPE_CHECKING, List, Tuple
|
|
9
|
+
|
|
10
|
+
import jax
|
|
11
|
+
import jax.numpy as jnp
|
|
12
|
+
import numpy as np
|
|
13
|
+
from jax import jacfwd
|
|
14
|
+
from jax.lax import cond
|
|
15
|
+
|
|
16
|
+
if TYPE_CHECKING:
|
|
17
|
+
from openscvx.lowered.unified import UnifiedState
|
|
18
|
+
from openscvx.symbolic.expr.state import State
|
|
19
|
+
|
|
20
|
+
from openscvx.lowered import (
|
|
21
|
+
Dynamics,
|
|
22
|
+
LoweredCrossNodeConstraint,
|
|
23
|
+
LoweredJaxConstraints,
|
|
24
|
+
LoweredNodalConstraint,
|
|
25
|
+
)
|
|
26
|
+
|
|
27
|
+
# Public API of this module: only the byof entry point is exported.
__all__: List[str] = ["apply_byof"]
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
def _penalty_square(r):
    """Quadratic hinge penalty ``max(r, 0) ** 2``."""
    return jnp.maximum(r, 0.0) ** 2


def _penalty_l1(r):
    """Linear hinge penalty ``max(r, 0)``."""
    return jnp.maximum(r, 0.0)


def _penalty_huber(r, delta=1.0):
    """Huber penalty on the positive part ``v = max(r, 0)``.

    Returns ``0.5 * v**2`` for ``v <= delta`` and ``delta * (v - 0.5 * delta)`` otherwise.
    """
    pos_r = jnp.maximum(r, 0.0)
    return jnp.where(pos_r <= delta, 0.5 * pos_r**2, delta * (pos_r - 0.5 * delta))


# Built-in CTCS penalty functions, selectable by name via a ctcs spec's "penalty" key.
_PENALTY_FUNCTIONS = {
    "square": _penalty_square,
    "l1": _penalty_l1,
    "huber": _penalty_huber,
}


def _parse_ctcs_aug_idx(name):
    """Return the CTCS idx encoded in an augmented-state name, or ``None``.

    Augmented CTCS states are named ``"_ctcs_aug_{i}"``; the idx is the integer after
    the last underscore. Names without that prefix, or whose suffix is not an
    integer, yield ``None``.
    """
    if not name.startswith("_ctcs_aug_"):
        return None
    try:
        return int(name.split("_")[-1])
    except (ValueError, IndexError):
        return None


def _ctcs_aug_slices(state_list):
    """Map CTCS idx -> ``_slice`` for every augmented CTCS state in ``state_list``.

    States that are not augmented CTCS states, or that carry no ``_slice``
    attribute, are skipped.
    """
    slices = {}
    for state in state_list:
        aug_idx = _parse_ctcs_aug_idx(state.name)
        if aug_idx is None:
            continue
        try:
            slices[aug_idx] = state._slice
        except AttributeError:
            continue
    return slices


def _with_jacobians(f):
    """Wrap ``f`` in a new ``Dynamics`` with forward-mode Jacobians in x (A) and u (B)."""
    return Dynamics(f=f, A=jacfwd(f, argnums=0), B=jacfwd(f, argnums=1))


def _make_composite_dynamics(orig_f, byof_fns, slices_map):
    """Splice user-provided (byof) state derivatives into a unified dynamics function.

    Args:
        orig_f: Unified dynamics ``(x, u, node, params) -> xdot``.
        byof_fns: Dict mapping state names to byof derivative functions with the same
            ``(x, u, node, params)`` signature.
        slices_map: Dict mapping state names to their slice in ``xdot``.

    Returns:
        Dynamics function that evaluates ``orig_f`` and then overwrites each byof
        state's slice of ``xdot`` with the matching byof function's output.
    """

    def composite_f(x, u, node, params):
        xdot = orig_f(x, u, node, params)
        for state_name, byof_fn in byof_fns.items():
            xdot = xdot.at[slices_map[state_name]].set(byof_fn(x, u, node, params))
        return xdot

    return composite_f


def _make_penalty_fn(cons_fn, pen_func, over):
    """Compose a constraint residual with a penalty, optionally gated by a node window.

    Args:
        cons_fn: Constraint function ``(x, u, node, params) -> residual``.
        pen_func: Penalty applied to the residual.
        over: ``None`` (always active) or a ``(start, end)`` pair; the penalty is
            active only when ``start <= node < end``.

    Returns:
        Function ``(x, u, node, params) -> penalty``. Outside the window the result is
        zeros with the same shape/dtype as the penalty.
    """

    def penalty_fn(x, u, node, params):
        penalty_value = pen_func(cons_fn(x, u, node, params))
        if over is None:
            return penalty_value

        start_node, end_node = over
        # ``node`` may arrive as a scalar or an array; keep it a traced JAX value.
        node_scalar = jnp.atleast_1d(node)[0]
        is_active = (start_node <= node_scalar) & (node_scalar < end_node)
        # A select instead of lax.cond: cond demands identical branch output types,
        # which a bare 0.0 only satisfies for 0-d penalties (e.g. fails for shape (1,)).
        return jnp.where(is_active, penalty_value, jnp.zeros_like(penalty_value))

    return penalty_fn


def _make_ctcs_addition(orig_f, pen_fns, aug_sl):
    """Return dynamics that add summed penalties to an existing augmented state's slice."""

    def modified_f(x, u, node, params):
        xdot = orig_f(x, u, node, params)
        total_penalty = sum(pen_fn(x, u, node, params) for pen_fn in pen_fns)
        return xdot.at[aug_sl].set(xdot[aug_sl] + total_penalty)

    return modified_f


def _make_ctcs_new_state(orig_f, pen_fns):
    """Return dynamics with the summed penalties appended as one new state derivative."""

    def augmented_f(x, u, node, params):
        xdot = orig_f(x, u, node, params)
        total_penalty = sum(pen_fn(x, u, node, params) for pen_fn in pen_fns)
        return jnp.concatenate([xdot, jnp.atleast_1d(total_penalty)])

    return augmented_f


def _make_ctcs_aug_state(idx, bounds, initial, N, start):
    """Create the 1-dim ``_ctcs_aug_{idx}`` State occupying ``slice(start, start + 1)``."""
    # Imported lazily, matching the original module's local import of State.
    from openscvx.symbolic.expr.state import State

    aug_state = State(f"_ctcs_aug_{idx}", shape=(1,))
    aug_state.min = np.array([bounds[0]])
    aug_state.max = np.array([bounds[1]])
    aug_state.initial = np.array([initial])
    aug_state.final = [("free", 0.0)]
    aug_state.guess = np.full((N, 1), initial)
    aug_state._slice = slice(start, start + 1)
    return aug_state


def apply_byof(
    byof: dict,
    dynamics: "Dynamics",
    dynamics_prop: "Dynamics",
    jax_constraints: "LoweredJaxConstraints",
    x_unified: "UnifiedState",
    x_prop_unified: "UnifiedState",
    u_unified: "UnifiedState",
    states: List["State"],
    states_prop: List["State"],
    N: int,
) -> Tuple["Dynamics", "Dynamics", "LoweredJaxConstraints", "UnifiedState", "UnifiedState"]:
    """Apply bring-your-own-functions (byof) to augment a lowered problem.

    Handles raw JAX functions provided by expert users:

    - dynamics: per-state derivative functions spliced into the unified dynamics
    - nodal_constraints: point-wise constraints, vmapped over nodes
    - cross_nodal_constraints: constraints coupling multiple nodes
    - ctcs_constraints: continuous-time constraint satisfaction via penalty states

    Args:
        byof: Dict with optional keys "dynamics", "nodal_constraints",
            "cross_nodal_constraints", "ctcs_constraints".
        dynamics: Lowered optimization dynamics.
        dynamics_prop: Lowered propagation dynamics.
        jax_constraints: Lowered JAX constraints; byof constraints are appended in place.
        x_unified: Unified optimization state; new CTCS states are appended in place.
        x_prop_unified: Unified propagation state; new CTCS states are appended in place.
        u_unified: Unified control interface (not read here; kept for the interface).
        states: Optimization State objects (with ``_slice``); new CTCS states are
            appended in place.
        states_prop: Propagation State objects (with ``_slice``); new CTCS states are
            appended in place.
        N: Number of nodes in the trajectory.

    Returns:
        Tuple of (dynamics, dynamics_prop, jax_constraints, x_unified, x_prop_unified).
        Any modified dynamics are returned as new ``Dynamics`` objects; the input
        ``Dynamics`` objects are never mutated.

    Raises:
        ValueError: If byof CTCS idx values, combined with the symbolic ones, are not
            a contiguous sequence starting at 0.
        KeyError: If a string "penalty" is not one of "square", "l1", "huber".

    Example:
        >>> dynamics, dynamics_prop, constraints, x_unified, x_prop_unified = apply_byof(
        ...     byof, dynamics, dynamics_prop, jax_constraints,
        ...     x_unified, x_prop_unified, u_unified, states, states_prop, N
        ... )
    """
    # Note: byof validation happens earlier in Problem.__init__ to fail fast.

    # --- Dynamics: splice raw JAX derivative functions in at each state's slice ---
    byof_dynamics = byof.get("dynamics", {})
    if byof_dynamics:
        state_slices = {state.name: state._slice for state in states}
        state_slices_prop = {state.name: state._slice for state in states_prop}
        dynamics = _with_jacobians(
            _make_composite_dynamics(dynamics.f, byof_dynamics, state_slices)
        )
        dynamics_prop = _with_jacobians(
            _make_composite_dynamics(dynamics_prop.f, byof_dynamics, state_slices_prop)
        )

    # --- Nodal constraints (validation already done in validate_byof) ---
    for constraint_spec in byof.get("nodal_constraints", []):
        fn = constraint_spec["constraint_fn"]
        nodes = constraint_spec.get("nodes", list(range(N)))  # Default: all nodes
        normalized_nodes = [node if node >= 0 else N + node for node in nodes]
        jax_constraints.nodal.append(
            LoweredNodalConstraint(
                func=jax.vmap(fn, in_axes=(0, 0, None, None)),
                grad_g_x=jax.vmap(jacfwd(fn, argnums=0), in_axes=(0, 0, None, None)),
                grad_g_u=jax.vmap(jacfwd(fn, argnums=1), in_axes=(0, 0, None, None)),
                nodes=normalized_nodes,
            )
        )

    # --- Cross-nodal constraints ---
    for fn in byof.get("cross_nodal_constraints", []):
        jax_constraints.cross_node.append(
            LoweredCrossNodeConstraint(
                func=fn,
                grad_g_X=jacfwd(fn, argnums=0),
                grad_g_U=jacfwd(fn, argnums=1),
            )
        )

    # --- CTCS constraints: augment dynamics with penalty-integrating states ---
    # Symbolic CTCS states are named "_ctcs_aug_{i}" with sequential i = 0, 1, 2, ...
    symbolic_ctcs_idx = [
        aug_idx
        for aug_idx in (_parse_ctcs_aug_idx(state.name) for state in states)
        if aug_idx is not None
    ]
    max_symbolic_idx = len(symbolic_ctcs_idx) - 1  # -1 when there are none
    idx_to_aug_slice = _ctcs_aug_slices(states)
    # Resolve propagation slices by name, since the propagation vector may lay
    # states out differently from the optimization vector.
    idx_to_aug_slice_prop = _ctcs_aug_slices(states_prop)

    # Group byof CTCS specs by idx (default 0)
    byof_ctcs_groups = {}
    for ctcs_spec in byof.get("ctcs_constraints", []):
        byof_ctcs_groups.setdefault(ctcs_spec.get("idx", 0), []).append(ctcs_spec)

    # All idx (symbolic + byof) must form the contiguous sequence [0, 1, ..., max_idx]
    if byof_ctcs_groups:
        all_idx = sorted(set(range(max_symbolic_idx + 1)) | set(byof_ctcs_groups.keys()))
        expected_idx = list(range(len(all_idx)))
        if all_idx != expected_idx:
            raise ValueError(
                f"BYOF CTCS idx values create non-contiguous sequence. "
                f"Symbolic CTCS has idx=[{', '.join(map(str, range(max_symbolic_idx + 1)))}], "
                f"combined with byof idx={sorted(byof_ctcs_groups.keys())} gives {all_idx}. "
                f"Expected contiguous sequence {expected_idx}. "
                f"Byof idx must either match existing symbolic idx or be sequential after them."
            )

    for idx in sorted(byof_ctcs_groups):
        specs = byof_ctcs_groups[idx]

        penalty_fns = []
        for spec in specs:
            constraint_fn = spec["constraint_fn"]
            penalty_spec = spec.get("penalty", "square")
            over_interval = spec.get("over", None)  # Node interval (start, end) or None
            penalty_func = (
                penalty_spec if callable(penalty_spec) else _PENALTY_FUNCTIONS[penalty_spec]
            )
            penalty_fns.append(_make_penalty_fn(constraint_fn, penalty_func, over_interval))

        if idx in idx_to_aug_slice:
            # idx already exists from symbolic CTCS: add penalties to that state.
            # Fresh Dynamics objects avoid mutating the caller's objects and avoid
            # double-adding when dynamics and dynamics_prop are the same object.
            aug_slice = idx_to_aug_slice[idx]
            aug_slice_prop = idx_to_aug_slice_prop.get(idx, aug_slice)
            dynamics = _with_jacobians(_make_ctcs_addition(dynamics.f, penalty_fns, aug_slice))
            dynamics_prop = _with_jacobians(
                _make_ctcs_addition(dynamics_prop.f, penalty_fns, aug_slice_prop)
            )
            continue

        # New idx: append a fresh augmented state; bounds/initial come from the first spec
        first_spec = specs[0]
        bounds = first_spec.get("bounds", (0.0, 1e-4))
        initial = first_spec.get("initial", bounds[0])

        dynamics = _with_jacobians(_make_ctcs_new_state(dynamics.f, penalty_fns))
        dynamics_prop = _with_jacobians(_make_ctcs_new_state(dynamics_prop.f, penalty_fns))

        # State objects are needed for CVXPy variable creation and bookkeeping.
        # Slices are taken from the unified dims *before* those are extended below.
        states.append(_make_ctcs_aug_state(idx, bounds, initial, N, x_unified.shape[0]))
        states_prop.append(
            _make_ctcs_aug_state(idx, bounds, initial, N, x_prop_unified.shape[0])
        )

        for unified in (x_unified, x_prop_unified):
            unified.append(
                min=bounds[0],
                max=bounds[1],
                guess=initial,
                initial=initial,
                final=0.0,
                augmented=True,
            )

    return dynamics, dynamics_prop, jax_constraints, x_unified, x_prop_unified
|