openscvx 0.3.2.dev170__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of openscvx might be problematic. Click here for more details.
- openscvx/__init__.py +123 -0
- openscvx/_version.py +34 -0
- openscvx/algorithms/__init__.py +92 -0
- openscvx/algorithms/autotuning.py +24 -0
- openscvx/algorithms/base.py +351 -0
- openscvx/algorithms/optimization_results.py +215 -0
- openscvx/algorithms/penalized_trust_region.py +384 -0
- openscvx/config.py +437 -0
- openscvx/discretization/__init__.py +47 -0
- openscvx/discretization/discretization.py +236 -0
- openscvx/expert/__init__.py +23 -0
- openscvx/expert/byof.py +326 -0
- openscvx/expert/lowering.py +419 -0
- openscvx/expert/validation.py +357 -0
- openscvx/integrators/__init__.py +48 -0
- openscvx/integrators/runge_kutta.py +281 -0
- openscvx/lowered/__init__.py +30 -0
- openscvx/lowered/cvxpy_constraints.py +23 -0
- openscvx/lowered/cvxpy_variables.py +124 -0
- openscvx/lowered/dynamics.py +34 -0
- openscvx/lowered/jax_constraints.py +133 -0
- openscvx/lowered/parameters.py +54 -0
- openscvx/lowered/problem.py +70 -0
- openscvx/lowered/unified.py +718 -0
- openscvx/plotting/__init__.py +63 -0
- openscvx/plotting/plotting.py +756 -0
- openscvx/plotting/scp_iteration.py +299 -0
- openscvx/plotting/viser/__init__.py +126 -0
- openscvx/plotting/viser/animated.py +605 -0
- openscvx/plotting/viser/plotly_integration.py +333 -0
- openscvx/plotting/viser/primitives.py +355 -0
- openscvx/plotting/viser/scp.py +459 -0
- openscvx/plotting/viser/server.py +112 -0
- openscvx/problem.py +734 -0
- openscvx/propagation/__init__.py +60 -0
- openscvx/propagation/post_processing.py +104 -0
- openscvx/propagation/propagation.py +248 -0
- openscvx/solvers/__init__.py +51 -0
- openscvx/solvers/cvxpy.py +226 -0
- openscvx/symbolic/__init__.py +9 -0
- openscvx/symbolic/augmentation.py +630 -0
- openscvx/symbolic/builder.py +492 -0
- openscvx/symbolic/constraint_set.py +92 -0
- openscvx/symbolic/expr/__init__.py +222 -0
- openscvx/symbolic/expr/arithmetic.py +517 -0
- openscvx/symbolic/expr/array.py +632 -0
- openscvx/symbolic/expr/constraint.py +796 -0
- openscvx/symbolic/expr/control.py +135 -0
- openscvx/symbolic/expr/expr.py +720 -0
- openscvx/symbolic/expr/lie/__init__.py +87 -0
- openscvx/symbolic/expr/lie/adjoint.py +357 -0
- openscvx/symbolic/expr/lie/se3.py +172 -0
- openscvx/symbolic/expr/lie/so3.py +138 -0
- openscvx/symbolic/expr/linalg.py +279 -0
- openscvx/symbolic/expr/math.py +699 -0
- openscvx/symbolic/expr/spatial.py +209 -0
- openscvx/symbolic/expr/state.py +607 -0
- openscvx/symbolic/expr/stl.py +136 -0
- openscvx/symbolic/expr/variable.py +321 -0
- openscvx/symbolic/hashing.py +112 -0
- openscvx/symbolic/lower.py +760 -0
- openscvx/symbolic/lowerers/__init__.py +106 -0
- openscvx/symbolic/lowerers/cvxpy.py +1302 -0
- openscvx/symbolic/lowerers/jax.py +1382 -0
- openscvx/symbolic/preprocessing.py +757 -0
- openscvx/symbolic/problem.py +110 -0
- openscvx/symbolic/time.py +116 -0
- openscvx/symbolic/unified.py +420 -0
- openscvx/utils/__init__.py +20 -0
- openscvx/utils/cache.py +131 -0
- openscvx/utils/caching.py +210 -0
- openscvx/utils/printing.py +301 -0
- openscvx/utils/profiling.py +37 -0
- openscvx/utils/utils.py +100 -0
- openscvx-0.3.2.dev170.dist-info/METADATA +350 -0
- openscvx-0.3.2.dev170.dist-info/RECORD +79 -0
- openscvx-0.3.2.dev170.dist-info/WHEEL +5 -0
- openscvx-0.3.2.dev170.dist-info/licenses/LICENSE +201 -0
- openscvx-0.3.2.dev170.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,236 @@
|
|
|
1
|
+
import jax.numpy as jnp
|
|
2
|
+
import numpy as np
|
|
3
|
+
|
|
4
|
+
from openscvx.config import Config
|
|
5
|
+
from openscvx.integrators import solve_ivp_diffrax, solve_ivp_rk45
|
|
6
|
+
from openscvx.lowered import Dynamics
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
def dVdt(
    tau: float,
    V: jnp.ndarray,
    u_cur: np.ndarray,
    u_next: np.ndarray,
    state_dot: callable,
    A: callable,
    B: callable,
    n_x: int,
    n_u: int,
    N: int,
    dis_type: str,
    params: dict,
) -> jnp.ndarray:
    """Compute the time derivative of the augmented state vector.

    For each of the ``N - 1`` intervals the augmented state is laid out as
    ``[x | vec(Phi) | vec(B_bar) | vec(C_bar)]``: the state, the state
    transition matrix, and the sensitivities of the state to the control at
    the start and at the end node of the interval. All intervals are stacked
    into one flat vector and differentiated together.

    Args:
        tau (float): Normalized time within the interval. ``calculate_discretization``
            integrates every interval over ``[0, 1/(N-1)]``.
        V (jnp.ndarray): Flattened augmented state; reshaped internally to
            ``(-1, n_x + n_x*n_x + 2*n_x*n_u)``.
        u_cur (np.ndarray): Control input at the start node of each interval.
            Its last column is used as the time-dilation factor ``s``.
        u_next (np.ndarray): Control input at the end node of each interval.
        state_dot (callable): Batched dynamics ``(x, u, nodes, params) -> f``.
        A (callable): Batched state Jacobian ``(x, u, nodes, params) -> df/dx``.
        B (callable): Batched control Jacobian ``(x, u, nodes, params) -> df/du``.
        n_x (int): Number of states.
        n_u (int): Number of controls, including the time-dilation column.
        N (int): Number of nodes in trajectory.
        dis_type (str): Discretization type ("ZOH" or "FOH").
        params (dict): Parameters forwarded to ``state_dot``, ``A`` and ``B``.

    Returns:
        jnp.ndarray: Flattened time derivative of the augmented state.

    Raises:
        ValueError: If ``dis_type`` is neither "ZOH" nor "FOH".
    """
    nodes = jnp.arange(0, N - 1)

    # Slice boundaries inside one interval's augmented state
    i0 = 0
    i1 = n_x
    i2 = i1 + n_x * n_x
    i3 = i2 + n_x * n_u
    i4 = i3 + n_x * n_u

    # One row per interval
    V = V.reshape(-1, i4)

    # Interpolation weight of the end-node control
    if dis_type == "ZOH":
        beta = 0.0
    elif dis_type == "FOH":
        # Each interval spans 1/(N-1) in normalized time, so the local
        # interpolation fraction must run from 0 to 1 across it.
        beta = tau * (N - 1)
    else:
        raise ValueError(
            f"Unknown discretization type {dis_type!r}; expected 'ZOH' or 'FOH'."
        )
    alpha = 1 - beta

    # Interpolate the control input; the last column is the time-dilation factor
    u = u_cur + beta * (u_next - u_cur)
    s = u[:, -1]

    # Ensure x and u have the same batch size
    x = V[:, :n_x]
    u = u[: x.shape[0]]
    u_dyn = u[:, :-1]

    # Nonlinear propagation term, scaled by time dilation
    f = state_dot(x, u_dyn, nodes, params)
    F = s[:, None] * f

    # State Jacobian of s * f
    dfdx = A(x, u_dyn, nodes, params)
    sdfdx = s[:, None, None] * dfdx

    # Control Jacobian of s * f: s * df/du for the physical controls and
    # d(s*f)/ds = f for the time-dilation column.
    dfdu_veh = B(x, u_dyn, nodes, params)
    dfdu = jnp.zeros((V.shape[0], n_x, n_u))
    dfdu = dfdu.at[:, :, :-1].set(s[:, None, None] * dfdu_veh)
    dfdu = dfdu.at[:, :, -1].set(f)

    # Stack the results into the augmented state derivative
    dV = jnp.zeros_like(V)
    dV = dV.at[:, i0:i1].set(F)
    dV = dV.at[:, i1:i2].set(
        jnp.matmul(sdfdx, V[:, i1:i2].reshape(-1, n_x, n_x)).reshape(-1, n_x * n_x)
    )
    dV = dV.at[:, i2:i3].set(
        (jnp.matmul(sdfdx, V[:, i2:i3].reshape(-1, n_x, n_u)) + dfdu * alpha).reshape(
            -1, n_x * n_u
        )
    )
    dV = dV.at[:, i3:i4].set(
        (jnp.matmul(sdfdx, V[:, i3:i4].reshape(-1, n_x, n_u)) + dfdu * beta).reshape(
            -1, n_x * n_u
        )
    )

    return dV.reshape(-1)
|
|
106
|
+
|
|
107
|
+
|
|
108
|
+
def calculate_discretization(
    x,
    u,
    state_dot: callable,
    A: callable,
    B: callable,
    settings: Config,
    params: dict,
):
    """Calculate the discretized system matrices.

    Integrates the augmented system defined by ``dVdt`` over one normalized
    interval ``[0, 1/(N-1)]`` for all ``N - 1`` intervals at once. Each
    interval starts from its node state, an identity state transition matrix
    and zero control sensitivities.

    Args:
        x: State trajectory. ``x[:-1]`` supplies the start state of each interval.
        u: Control trajectory. ``u[:-1]`` and ``u[1:]`` are the controls at the
            start and end node of each interval.
        state_dot (callable): Function computing state derivatives.
        A (callable): Function computing state Jacobian.
        B (callable): Function computing control Jacobian.
        settings (Config): Configuration. Fields read here:

            - ``sim.n_states`` / ``sim.n_controls``: state and control dimensions
            - ``scp.n``: number of trajectory nodes ``N``
            - ``dis.dis_type``: "ZOH" or "FOH"
            - ``dis.custom_integrator``: use ``solve_ivp_rk45`` instead of
              ``solve_ivp_diffrax``
            - ``dev.debug``: forwarded to ``solve_ivp_rk45`` as ``is_not_compiled``
            - ``dis.solver``, ``dis.rtol``, ``dis.atol``, ``dis.args``: solver
              name, tolerances and extra keyword arguments for ``solve_ivp_diffrax``
        params (dict): Parameters forwarded to ``state_dot``, ``A`` and ``B``.

    Returns:
        tuple: ``(A_bar, B_bar, C_bar, x_prop, Vmulti)`` where:

            - A_bar: ``(N-1, n_x, n_x)`` discretized state transition matrices
            - B_bar: ``(N-1, n_x, n_u)`` sensitivity to the control at the current node
            - C_bar: ``(N-1, n_x, n_u)`` sensitivity to the control at the next node
            - x_prop: ``(N-1, n_x)`` states propagated to the end of each interval
            - Vmulti: transpose of the full integrator output (the augmented state
              at every stored integration step)
    """
    # Unpack settings
    n_x = settings.sim.n_states
    n_u = settings.sim.n_controls
    N = settings.scp.n

    # Slice boundaries of one interval's augmented state [x | Phi | B_bar | C_bar]
    i0 = 0
    i1 = n_x
    i2 = i1 + n_x * n_x
    i3 = i2 + n_x * n_u
    i4 = i3 + n_x * n_u

    # Initial augmented state: node state and identity STM for every interval;
    # the control sensitivity blocks start at zero.
    V0 = jnp.zeros((N - 1, i4))
    V0 = V0.at[:, i0:i1].set(x[:-1].astype(float))
    V0 = V0.at[:, i1:i2].set(jnp.eye(n_x).reshape(1, -1).repeat(N - 1, axis=0))

    integrator_args = dict(
        u_cur=u[:-1].astype(float),
        u_next=u[1:].astype(float),
        state_dot=state_dot,
        A=A,
        B=B,
        n_x=n_x,
        n_u=n_u,
        N=N,
        dis_type=settings.dis.dis_type,
        params=params,  # Pass params as single dict
    )

    # Integrator-facing right-hand side (t, y) -> dy/dt
    def dVdt_wrapped(t, y):
        return dVdt(t, y, **integrator_args)

    # Choose integrator
    if settings.dis.custom_integrator:
        sol = solve_ivp_rk45(
            dVdt_wrapped,
            1.0 / (N - 1),
            V0.reshape(-1),
            args=(),
            is_not_compiled=settings.dev.debug,
        )
    else:
        sol = solve_ivp_diffrax(
            dVdt_wrapped,
            1.0 / (N - 1),
            V0.reshape(-1),
            solver_name=settings.dis.solver,
            rtol=settings.dis.rtol,
            atol=settings.dis.atol,
            args=(),
            extra_kwargs=settings.dis.args,
        )

    # The last stored integrator output is taken as the augmented state at the
    # end of every interval.
    Vend = sol[-1].T.reshape(-1, i4)
    Vmulti = sol.T

    x_prop = Vend[:, i0:i1]

    # Return as 3D arrays: (N-1, n_x, n_x) for A_bar, (N-1, n_x, n_u) for B_bar/C_bar
    A_bar = Vend[:, i1:i2].reshape(N - 1, n_x, n_x)
    B_bar = Vend[:, i2:i3].reshape(N - 1, n_x, n_u)
    C_bar = Vend[:, i3:i4].reshape(N - 1, n_x, n_u)

    return A_bar, B_bar, C_bar, x_prop, Vmulti
|
|
213
|
+
|
|
214
|
+
|
|
215
|
+
def get_discretization_solver(dyn: Dynamics, settings: Config):
    """Build a closure that discretizes the dynamics held by ``dyn``.

    Args:
        dyn (Dynamics): Dynamics object supplying ``f``, ``A`` and ``B``.
        settings (Config): Configuration forwarded unchanged to
            ``calculate_discretization``.

    Returns:
        callable: ``solver(x, u, params)``, which returns the same tuple as
        ``calculate_discretization``.
    """

    def discretization_solver(x, u, params):
        # ``dyn`` attributes are looked up on every call, not when the closure is built.
        return calculate_discretization(
            x=x,
            u=u,
            state_dot=dyn.f,
            A=dyn.A,
            B=dyn.B,
            settings=settings,
            params=params,
        )

    return discretization_solver
|
|
@@ -0,0 +1,23 @@
|
|
|
1
|
+
"""Expert-mode features for advanced users.
|
|
2
|
+
|
|
3
|
+
This module contains features for expert users who need fine-grained control
|
|
4
|
+
and are willing to bypass higher-level abstractions.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
from openscvx.expert.byof import (
|
|
8
|
+
ByofSpec,
|
|
9
|
+
CtcsConstraintSpec,
|
|
10
|
+
NodalConstraintSpec,
|
|
11
|
+
PenaltyFunction,
|
|
12
|
+
)
|
|
13
|
+
from openscvx.expert.lowering import apply_byof
|
|
14
|
+
from openscvx.expert.validation import validate_byof
|
|
15
|
+
|
|
16
|
+
# Public API of ``openscvx.expert``: the BYOF type definitions from
# ``openscvx.expert.byof`` plus ``apply_byof`` (lowering) and
# ``validate_byof`` (validation).
__all__ = [
    "ByofSpec",
    "CtcsConstraintSpec",
    "NodalConstraintSpec",
    "PenaltyFunction",
    "apply_byof",
    "validate_byof",
]
|
openscvx/expert/byof.py
ADDED
|
@@ -0,0 +1,326 @@
|
|
|
1
|
+
"""Bring-Your-Own-Functions (BYOF) - Expert User Mode.
|
|
2
|
+
|
|
3
|
+
This module provides type definitions and documentation for expert users who want
|
|
4
|
+
to bypass the symbolic layer and directly provide raw JAX functions.
|
|
5
|
+
|
|
6
|
+
Important:
|
|
7
|
+
The unified state/control vectors include ALL states/controls in the order
|
|
8
|
+
they were provided, plus any augmented states from CTCS constraints. You are
|
|
9
|
+
responsible for correct indexing. Consider inspecting the symbolic problem
|
|
10
|
+
to understand the layout.
|
|
11
|
+
|
|
12
|
+
Warning:
|
|
13
|
+
**Constraint Sign Convention**: All constraints follow g(x,u) <= 0 convention.
|
|
14
|
+
Return **negative when satisfied**, **positive when violated**.
|
|
15
|
+
Example: for x <= 10 return ``x - 10``, for x >= 5 return ``5 - x``.
|
|
16
|
+
|
|
17
|
+
Function Signatures:
|
|
18
|
+
All byof functions must be JAX-compatible (use jax.numpy, avoid side effects).
|
|
19
|
+
|
|
20
|
+
- dynamics: ``(x, u, node, params) -> xdot_component``
|
|
21
|
+
- x: Full unified state vector (1D array)
|
|
22
|
+
- u: Full unified control vector (1D array)
|
|
23
|
+
- node: Integer node index
|
|
24
|
+
- params: Dict of parameters
|
|
25
|
+
- Returns: State derivative component (array matching state shape)
|
|
26
|
+
|
|
27
|
+
- nodal_constraints: ``(x, u, node, params) -> residual``
|
|
28
|
+
- Same arguments as dynamics
|
|
29
|
+
- Returns: Constraint residual (g <= 0: negative=satisfied, positive=violated)
|
|
30
|
+
|
|
31
|
+
- cross_nodal_constraints: ``(X, U, params) -> residual``
|
|
32
|
+
- X: State trajectory (N, n_x) where N is number of trajectory nodes,
|
|
33
|
+
n_x is unified state dimension
|
|
34
|
+
- U: Control trajectory (N, n_u) where N is number of trajectory nodes,
|
|
35
|
+
n_u is unified control dimension
|
|
36
|
+
- params: Dict of parameters
|
|
37
|
+
- Returns: Constraint residual (g <= 0: negative=satisfied, positive=violated)
|
|
38
|
+
|
|
39
|
+
- ctcs constraint_fn: ``(x, u, node, params) -> scalar``
|
|
40
|
+
- Same as nodal_constraints but MUST return scalar
|
|
41
|
+
- Returns: Constraint residual (g <= 0: negative=satisfied, positive=violated)
|
|
42
|
+
|
|
43
|
+
- ctcs penalty: ``(residual) -> penalty_value``
|
|
44
|
+
- residual: Scalar constraint residual
|
|
45
|
+
- Returns: Non-negative penalty value
|
|
46
|
+
|
|
47
|
+
Example:
|
|
48
|
+
Basic usage mixing symbolic and byof::
|
|
49
|
+
|
|
50
|
+
import jax.numpy as jnp
|
|
51
|
+
import openscvx as ox
|
|
52
|
+
from openscvx import ByofSpec
|
|
53
|
+
|
|
54
|
+
# Define states
|
|
55
|
+
position = ox.State("position", shape=(2,))
|
|
56
|
+
velocity = ox.State("velocity", shape=(1,))
|
|
57
|
+
theta = ox.Control("theta", shape=(1,))
|
|
58
|
+
|
|
59
|
+
# Unified state: [position[0], position[1], velocity[0], time, augmented...]
|
|
60
|
+
# Unified control: [theta[0], time_dilation]
|
|
61
|
+
|
|
62
|
+
# Tip: Use the .slice property on State/Control objects for cleaner,
|
|
63
|
+
# more maintainable indexing instead of hardcoded indices.
|
|
64
|
+
byof: ByofSpec = {
|
|
65
|
+
"nodal_constraints": [
|
|
66
|
+
# Velocity bounds (applied to all nodes)
|
|
67
|
+
{
|
|
68
|
+
"constraint_fn": lambda x, u, node, params: x[velocity.slice][0] - 10.0,
|
|
69
|
+
},
|
|
70
|
+
{
|
|
71
|
+
"constraint_fn": lambda x, u, node, params: -x[velocity.slice][0],
|
|
72
|
+
},
|
|
73
|
+
# Velocity must be exactly 0 at start (selective enforcement)
|
|
74
|
+
{
|
|
75
|
+
"constraint_fn": lambda x, u, node, params: x[velocity.slice][0],
|
|
76
|
+
"nodes": [0], # Only at first node
|
|
77
|
+
},
|
|
78
|
+
],
|
|
79
|
+
"ctcs_constraints": [
|
|
80
|
+
{
|
|
81
|
+
"constraint_fn": lambda x, u, node, params: x[position.slice][0] - 10.0,
|
|
82
|
+
"penalty": "square",
|
|
83
|
+
"bounds": (0.0, 1e-4),
|
|
84
|
+
}
|
|
85
|
+
],
|
|
86
|
+
}
|
|
87
|
+
|
|
88
|
+
problem = ox.Problem(..., byof=byof)
|
|
89
|
+
"""
|
|
90
|
+
|
|
91
|
+
from typing import TYPE_CHECKING, Any, Callable, List, Literal, Tuple, TypedDict, Union
|
|
92
|
+
|
|
93
|
+
if TYPE_CHECKING:
|
|
94
|
+
from jax import Array as JaxArray
|
|
95
|
+
else:
|
|
96
|
+
JaxArray = Any
|
|
97
|
+
|
|
98
|
+
__all__ = [
    "ByofSpec",
    "CtcsConstraintSpec",
    "NodalConstraintSpec",
    "PenaltyFunction",
]


# ---------------------------------------------------------------------------
# Callable signatures accepted inside a ByofSpec
# ---------------------------------------------------------------------------

# Per-node state derivative: (x, u, node, params) -> xdot component.
DynamicsFunction = Callable[[JaxArray, JaxArray, int, dict], JaxArray]

# Per-node residual, g(x, u) <= 0: (x, u, node, params) -> residual.
NodalConstraintFunction = Callable[[JaxArray, JaxArray, int, dict], JaxArray]

# Whole-trajectory residual: (X, U, params) -> residual.
CrossNodalConstraintFunction = Callable[[JaxArray, JaxArray, dict], JaxArray]

# Scalar per-node residual used by CTCS: (x, u, node, params) -> scalar.
CtcsConstraintFunction = Callable[[JaxArray, JaxArray, int, dict], float]

# Name of a built-in penalty, or a custom (residual) -> penalty callable.
PenaltyFunction = Union[
    Literal["square", "l1", "huber"],
    Callable[[float], float],
]
|
|
107
|
+
|
|
108
|
+
|
|
109
|
+
class NodalConstraintSpec(TypedDict, total=False):
    """A point-wise constraint, optionally restricted to a subset of nodes.

    The constraint is checked at individual trajectory nodes rather than
    continuously. Omit ``nodes`` to enforce it at every node. List specific
    nodes to express boundary conditions or waypoints, or to save computation.

    Attributes:
        constraint_fn: Required. Residual function with signature
            ``(x, u, node, params) -> residual``, using the g(x,u) <= 0
            convention (negative = satisfied).
        nodes: Optional. Integer node indices at which the constraint is
            enforced; all nodes when omitted. Negative indices are supported
            (e.g., -1 for the last node).

    Note:
        The class is declared with ``total=False``, so static type checkers
        treat every key as optional, including ``constraint_fn``.

    Example:
        Enforce ``velocity <= 0`` only at the first and last nodes::

            nodal_spec: NodalConstraintSpec = {
                "constraint_fn": lambda x, u, node, params: x[velocity.slice][0],
                "nodes": [0, -1],  # Only at start and end
            }

        Require position to be within 0.1 of a waypoint at the middle node
        (``N`` is the number of trajectory nodes)::

            nodal_spec: NodalConstraintSpec = {
                "constraint_fn": lambda x, u, node, params: jnp.linalg.norm(
                    x[position.slice] - jnp.array([5.0, 7.5])
                ) - 0.1,
                "nodes": [N // 2],
            }
    """

    constraint_fn: NodalConstraintFunction  # Required
    nodes: List[int]
|
|
143
|
+
|
|
144
|
+
|
|
145
|
+
class CtcsConstraintSpec(TypedDict, total=False):
    """Specification for CTCS (Continuous-Time Constraint Satisfaction) constraint.

    CTCS constraints are enforced by augmenting the dynamics with a penalty term that
    accumulates violations over time. This is useful for path constraints that must be
    satisfied continuously, not just at discrete nodes.

    Attributes:
        constraint_fn: Function computing the constraint residual with signature
            ``(x, u, node, params) -> scalar``. Must return a scalar.
            Follows the g(x,u) <= 0 convention (negative = satisfied). Required field.
        penalty: Penalty function for positive residuals (violations).
            Built-in options: ``"square"`` (max(r,0)^2, default), ``"l1"`` (max(r,0)),
            ``"huber"`` (Huber loss). Custom: Callable ``(r) -> penalty`` (non-negative,
            differentiable).
        bounds: ``(min, max)`` bounds for the augmented state accumulating penalties.
            Default: ``(0.0, 1e-4)``. The max acts as a soft constraint on total violation.
        initial: Initial value for the augmented state. Default: ``bounds[0]`` (usually 0.0).
        over: Node interval ``(start, end)`` where the constraint is active. The constraint
            is enforced for nodes in ``[start, end)``. If omitted, the constraint is active
            over all nodes. Matches the behavior of the symbolic ``.over()`` method.
        idx: Constraint group index for sharing augmented states (default: 0).
            All CTCS constraints (symbolic and byof) with the same idx share a single
            augmented state, and their penalties are summed together. Use different idx
            values to track different types of violations separately.

    Note:
        The class is declared with ``total=False``, so static type checkers treat every
        key as optional, including ``constraint_fn``.

    Warning:
        If symbolic CTCS constraints exist with idx values [0, 1, 2], then byof idx **must** either:

        - Match an existing idx (e.g., 0, 1, or 2) to add to that augmented state
        - Be sequential after them (e.g., 3, 4, 5) to create new augmented states

        You cannot use idx values that create gaps (e.g., if symbolic has [0, 1],
        you cannot use byof idx=3 without also using idx=2).

    Example:
        Enforce position[0] <= 10.0 continuously::

            # Assuming position = ox.State("position", shape=(2,))
            ctcs_spec: CtcsConstraintSpec = {
                "constraint_fn": lambda x, u, node, params: x[position.slice][0] - 10.0,
                "penalty": "square",
                "bounds": (0.0, 1e-4),
                "initial": 0.0,
                "idx": 0,  # Groups with other constraints having idx=0
            }

        Enforce constraint only over a specific node range::

            ctcs_spec: CtcsConstraintSpec = {
                "constraint_fn": lambda x, u, node, params: x[position.slice][0] - 10.0,
                "over": (10, 50),  # Active only for nodes 10-49
                "penalty": "square",
            }

        Multiple constraints sharing an augmented state::

            # Assuming position = ox.State("position", shape=(2,)),
            #          velocity = ox.State("velocity", shape=(1,)),
            # and symbolic CTCS already has idx=[0, 1]:

            byof = {
                "ctcs_constraints": [
                    # Add to existing symbolic idx=0 augmented state
                    {
                        "constraint_fn": lambda x, u, node, params: x[position.slice][0] - 10.0,
                        "idx": 0,  # Shares with symbolic idx=0
                    },
                    # Add to existing symbolic idx=1 augmented state
                    {
                        "constraint_fn": lambda x, u, node, params: x[velocity.slice][0] - 5.0,
                        "idx": 1,  # Shares with symbolic idx=1
                    },
                    # Create NEW augmented state (sequential after symbolic)
                    {
                        "constraint_fn": lambda x, u, node, params: x[position.slice][1] - 8.0,
                        "idx": 2,  # New state (symbolic has 0,1, so next is 2)
                    },
                ]
            }
    """

    constraint_fn: CtcsConstraintFunction  # Required
    penalty: PenaltyFunction
    bounds: Tuple[float, float]
    initial: float
    over: Tuple[int, int]
    idx: int
|
|
231
|
+
|
|
232
|
+
|
|
233
|
+
class ByofSpec(TypedDict, total=False):
    """Bring-Your-Own-Functions specification for expert users.

    Allows bypassing the symbolic layer and directly providing raw JAX functions.
    All fields are optional - you can mix symbolic and byof as needed.

    Warning:
        You are responsible for:

        - Correct indexing into unified state/control vectors
        - Ensuring functions are JAX-compatible (use jax.numpy, no side effects)
        - Ensuring functions are differentiable
        - Following g(x,u) <= 0 convention for constraints

    Tip:
        Use the ``.slice`` property on State/Control objects for cleaner, more
        maintainable indexing instead of hardcoded indices. For example, use
        ``x[velocity.slice]`` instead of ``x[2:3]``. The slice property is set
        after preprocessing and provides the correct indices into the unified
        state/control vectors.

    Attributes:
        dynamics: Raw JAX functions for state derivatives. Maps state names to functions
            with signature ``(x, u, node, params) -> xdot_component``. States listed here
            should NOT appear in the symbolic dynamics dict. You can mix: some states
            symbolic, some in byof.
        nodal_constraints: Point-wise constraints applied at specific nodes.
            Each item is a :class:`NodalConstraintSpec` dict with:

            - ``constraint_fn``: Constraint function ``(x, u, node, params) -> residual``
              (required)
            - ``nodes``: List of node indices (optional, defaults to all nodes)

            Follows g(x,u) <= 0 convention.
        cross_nodal_constraints: Constraints coupling multiple nodes (smoothness, rate limits).
            Signature: ``(X, U, params) -> residual`` where X is (N, n_x) and U is (N, n_u).
            N is the number of trajectory nodes, n_x is state dimension, n_u is control
            dimension. Follows g(X,U) <= 0 convention.
        ctcs_constraints: Continuous-time constraint satisfaction via dynamics augmentation.
            Each adds to an augmented state accumulating violation penalties.
            See :class:`CtcsConstraintSpec` for details.

    Example:
        Custom dynamics and constraints::

            import jax.numpy as jnp
            import openscvx as ox
            from openscvx import ByofSpec

            # Define states and controls
            position = ox.State("position", shape=(2,))
            velocity = ox.State("velocity", shape=(1,))
            theta = ox.Control("theta", shape=(1,))

            # Custom dynamics for one state using .slice property
            def custom_velocity_dynamics(x, u, node, params):
                # Use .slice property for clean indexing
                return params["g"] * jnp.cos(u[theta.slice][0])

            byof: ByofSpec = {
                "dynamics": {
                    "velocity": custom_velocity_dynamics,
                },
                "nodal_constraints": [
                    # Applied to all nodes (no "nodes" field): 0 <= velocity <= 10
                    {
                        "constraint_fn": lambda x, u, node, params: x[velocity.slice][0] - 10.0,
                    },
                    {
                        "constraint_fn": lambda x, u, node, params: -x[velocity.slice][0],
                    },
                    # Specify nodes for selective enforcement: velocity <= 0 at node 0,
                    # which together with velocity >= 0 above pins it to exactly 0 at start
                    {
                        "constraint_fn": lambda x, u, node, params: x[velocity.slice][0],
                        "nodes": [0],
                    },
                ],
                "cross_nodal_constraints": [
                    # Constrain total velocity across trajectory: sum(velocities) >= 5
                    # X.shape = (N, n_x), extract velocity column using slice
                    lambda X, U, params: 5.0 - jnp.sum(X[:, velocity.slice]),
                ],
                "ctcs_constraints": [
                    {
                        "constraint_fn": lambda x, u, node, params: x[position.slice][0] - 5.0,
                        "penalty": "square",
                    }
                ],
            }
    """

    dynamics: dict[str, DynamicsFunction]
    nodal_constraints: List[NodalConstraintSpec]
    cross_nodal_constraints: List[CrossNodalConstraintFunction]
    ctcs_constraints: List[CtcsConstraintSpec]
|