openscvx 0.3.2.dev170__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of openscvx might be problematic. Click here for more details.
- openscvx/__init__.py +123 -0
- openscvx/_version.py +34 -0
- openscvx/algorithms/__init__.py +92 -0
- openscvx/algorithms/autotuning.py +24 -0
- openscvx/algorithms/base.py +351 -0
- openscvx/algorithms/optimization_results.py +215 -0
- openscvx/algorithms/penalized_trust_region.py +384 -0
- openscvx/config.py +437 -0
- openscvx/discretization/__init__.py +47 -0
- openscvx/discretization/discretization.py +236 -0
- openscvx/expert/__init__.py +23 -0
- openscvx/expert/byof.py +326 -0
- openscvx/expert/lowering.py +419 -0
- openscvx/expert/validation.py +357 -0
- openscvx/integrators/__init__.py +48 -0
- openscvx/integrators/runge_kutta.py +281 -0
- openscvx/lowered/__init__.py +30 -0
- openscvx/lowered/cvxpy_constraints.py +23 -0
- openscvx/lowered/cvxpy_variables.py +124 -0
- openscvx/lowered/dynamics.py +34 -0
- openscvx/lowered/jax_constraints.py +133 -0
- openscvx/lowered/parameters.py +54 -0
- openscvx/lowered/problem.py +70 -0
- openscvx/lowered/unified.py +718 -0
- openscvx/plotting/__init__.py +63 -0
- openscvx/plotting/plotting.py +756 -0
- openscvx/plotting/scp_iteration.py +299 -0
- openscvx/plotting/viser/__init__.py +126 -0
- openscvx/plotting/viser/animated.py +605 -0
- openscvx/plotting/viser/plotly_integration.py +333 -0
- openscvx/plotting/viser/primitives.py +355 -0
- openscvx/plotting/viser/scp.py +459 -0
- openscvx/plotting/viser/server.py +112 -0
- openscvx/problem.py +734 -0
- openscvx/propagation/__init__.py +60 -0
- openscvx/propagation/post_processing.py +104 -0
- openscvx/propagation/propagation.py +248 -0
- openscvx/solvers/__init__.py +51 -0
- openscvx/solvers/cvxpy.py +226 -0
- openscvx/symbolic/__init__.py +9 -0
- openscvx/symbolic/augmentation.py +630 -0
- openscvx/symbolic/builder.py +492 -0
- openscvx/symbolic/constraint_set.py +92 -0
- openscvx/symbolic/expr/__init__.py +222 -0
- openscvx/symbolic/expr/arithmetic.py +517 -0
- openscvx/symbolic/expr/array.py +632 -0
- openscvx/symbolic/expr/constraint.py +796 -0
- openscvx/symbolic/expr/control.py +135 -0
- openscvx/symbolic/expr/expr.py +720 -0
- openscvx/symbolic/expr/lie/__init__.py +87 -0
- openscvx/symbolic/expr/lie/adjoint.py +357 -0
- openscvx/symbolic/expr/lie/se3.py +172 -0
- openscvx/symbolic/expr/lie/so3.py +138 -0
- openscvx/symbolic/expr/linalg.py +279 -0
- openscvx/symbolic/expr/math.py +699 -0
- openscvx/symbolic/expr/spatial.py +209 -0
- openscvx/symbolic/expr/state.py +607 -0
- openscvx/symbolic/expr/stl.py +136 -0
- openscvx/symbolic/expr/variable.py +321 -0
- openscvx/symbolic/hashing.py +112 -0
- openscvx/symbolic/lower.py +760 -0
- openscvx/symbolic/lowerers/__init__.py +106 -0
- openscvx/symbolic/lowerers/cvxpy.py +1302 -0
- openscvx/symbolic/lowerers/jax.py +1382 -0
- openscvx/symbolic/preprocessing.py +757 -0
- openscvx/symbolic/problem.py +110 -0
- openscvx/symbolic/time.py +116 -0
- openscvx/symbolic/unified.py +420 -0
- openscvx/utils/__init__.py +20 -0
- openscvx/utils/cache.py +131 -0
- openscvx/utils/caching.py +210 -0
- openscvx/utils/printing.py +301 -0
- openscvx/utils/profiling.py +37 -0
- openscvx/utils/utils.py +100 -0
- openscvx-0.3.2.dev170.dist-info/METADATA +350 -0
- openscvx-0.3.2.dev170.dist-info/RECORD +79 -0
- openscvx-0.3.2.dev170.dist-info/WHEEL +5 -0
- openscvx-0.3.2.dev170.dist-info/licenses/LICENSE +201 -0
- openscvx-0.3.2.dev170.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,357 @@
|
|
|
1
|
+
"""Validation for bring-your-own-functions (byof).
|
|
2
|
+
|
|
3
|
+
This module provides validation for user-provided JAX functions in expert mode,
|
|
4
|
+
checking signatures, shapes, and differentiability before use.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
import inspect
|
|
8
|
+
from typing import TYPE_CHECKING, List
|
|
9
|
+
|
|
10
|
+
if TYPE_CHECKING:
|
|
11
|
+
from openscvx.symbolic.expr.state import State
|
|
12
|
+
|
|
13
|
+
# Public API of this module: only the top-level validator is exported.
__all__ = ["validate_byof"]
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
def validate_byof(
    byof: dict,
    states: List["State"],
    n_x: int,
    n_u: int,
    N: int = None,
) -> None:
    """Validate byof function signatures and shapes.

    Checks that user-provided functions have the correct signatures and return
    appropriate shapes. Every function is called once on zero-valued dummy inputs
    and differentiated with ``jax.grad``. Problems therefore surface here with a
    clear error message.

    Args:
        byof: Dictionary of user-provided functions to validate. Accepted keys:

            - ``"dynamics"``: dict mapping a state name to ``f(x, u, node, params)``.
              Its output must have that state's shape.
            - ``"nodal_constraints"``: list of dicts with a required
              ``"constraint_fn"`` (``f(x, u, node, params)``) and optional
              ``"nodes"`` (non-empty list/tuple of integer node indices).
            - ``"cross_nodal_constraints"``: list of ``f(X, U, params)`` callables.
            - ``"ctcs_constraints"``: list of dicts with a required
              ``"constraint_fn"`` (``f(x, u, node, params)`` returning a scalar) and
              optional ``"penalty"``, ``"idx"``, ``"bounds"``, ``"initial"`` and
              ``"over"`` entries.
        states: List of State objects for determining expected shapes
        n_x: Total dimension of the unified state vector
        n_u: Total dimension of the unified control vector
        N: Number of nodes in the trajectory (optional). If provided (and
            positive), it is used to:

            - validate node indices in nodal constraints;
            - validate the ``"over"`` interval of CTCS constraints;
            - size the dummy trajectory used to test cross-nodal constraints.

            Otherwise a 10-node dummy trajectory is used.

    Raises:
        ValueError: If any function has an invalid signature, fails on its test
            call, is not differentiable with JAX, or returns the wrong shape. Also
            raised if a spec entry has an invalid value.
        TypeError: If functions are not callable, or if a spec (or one of its
            entries, such as a node index) has the wrong type.

    Example:
        >>> validate_byof(byof, states, n_x=10, n_u=3, N=50)  # Raises if invalid
    """
    import operator
    from collections.abc import Mapping

    import jax
    import jax.numpy as jnp

    # Validate byof keys
    valid_keys = {"dynamics", "nodal_constraints", "cross_nodal_constraints", "ctcs_constraints"}
    invalid_keys = set(byof.keys()) - valid_keys
    if invalid_keys:
        raise ValueError(f"Unknown byof keys: {invalid_keys}. Valid keys: {valid_keys}")

    # Create dummy single-node inputs for testing
    dummy_x = jnp.zeros(n_x)
    dummy_u = jnp.zeros(n_u)
    dummy_node = 0
    dummy_params = {}

    # Validate dynamics functions
    byof_dynamics = byof.get("dynamics", {})
    if byof_dynamics:
        # Without this check a list of functions would fail below with an
        # opaque AttributeError on ``.items()``.
        if not isinstance(byof_dynamics, Mapping):
            raise TypeError(
                "byof 'dynamics' must be a dict mapping state names to functions, "
                f"got {type(byof_dynamics)}"
            )

        # Build mapping from state name to expected shape
        state_shapes = {state.name: state.shape for state in states}

        for state_name, fn in byof_dynamics.items():
            if state_name not in state_shapes:
                raise ValueError(
                    f"byof dynamics '{state_name}' does not match any state name. "
                    f"Available states: {list(state_shapes.keys())}"
                )

            if not callable(fn):
                raise TypeError(f"byof dynamics '{state_name}' must be callable, got {type(fn)}")

            # Check signature
            sig = inspect.signature(fn)
            if len(sig.parameters) != 4:
                raise ValueError(
                    f"byof dynamics '{state_name}' must have signature f(x, u, node, params), "
                    f"got {len(sig.parameters)} parameters: {list(sig.parameters.keys())}"
                )

            # Test call and check output shape
            try:
                result = fn(dummy_x, dummy_u, dummy_node, dummy_params)
            except Exception as e:
                raise ValueError(
                    f"byof dynamics '{state_name}' failed on test call with "
                    f"x.shape={dummy_x.shape}, u.shape={dummy_u.shape}: {e}"
                ) from e

            expected_shape = state_shapes[state_name]
            result_shape = jnp.asarray(result).shape
            if result_shape != expected_shape:
                raise ValueError(
                    f"byof dynamics '{state_name}' returned shape {result_shape}, "
                    f"expected {expected_shape} (state '{state_name}' shape)"
                )

            # Test that gradient works (JAX compatibility check)
            try:
                jax.grad(lambda x: jnp.sum(fn(x, dummy_u, dummy_node, dummy_params)))(dummy_x)
            except Exception as e:
                raise ValueError(
                    f"byof dynamics '{state_name}' is not differentiable with JAX. "
                    f"Ensure the function uses JAX operations (jax.numpy, not numpy): {e}"
                ) from e

    # Validate nodal constraints
    for i, constraint_spec in enumerate(byof.get("nodal_constraints", [])):
        if not isinstance(constraint_spec, dict):
            raise TypeError(
                f"byof nodal_constraints[{i}] must be a dict (NodalConstraintSpec), "
                f"got {type(constraint_spec)}"
            )

        if "constraint_fn" not in constraint_spec:
            raise ValueError(f"byof nodal_constraints[{i}] missing required key 'constraint_fn'")

        fn = constraint_spec["constraint_fn"]
        if not callable(fn):
            raise TypeError(
                f"byof nodal_constraints[{i}]['constraint_fn'] must be callable, got {type(fn)}"
            )

        # Check signature
        sig = inspect.signature(fn)
        if len(sig.parameters) != 4:
            raise ValueError(
                f"byof nodal_constraints[{i}]['constraint_fn'] must have signature "
                f"f(x, u, node, params), "
                f"got {len(sig.parameters)} parameters: {list(sig.parameters.keys())}"
            )

        # Test call
        try:
            result = fn(dummy_x, dummy_u, dummy_node, dummy_params)
        except Exception as e:
            raise ValueError(
                f"byof nodal_constraints[{i}]['constraint_fn'] failed on test call with "
                f"x.shape={dummy_x.shape}, u.shape={dummy_u.shape}: {e}"
            ) from e

        # Check that result is array-like (can be scalar or vector)
        try:
            jnp.asarray(result)
        except Exception as e:
            raise ValueError(
                f"byof nodal_constraints[{i}]['constraint_fn'] must return array-like value, "
                f"got {type(result)}: {e}"
            ) from e

        # Test gradient
        try:
            jax.grad(lambda x: jnp.sum(fn(x, dummy_u, dummy_node, dummy_params)))(dummy_x)
        except Exception as e:
            raise ValueError(
                f"byof nodal_constraints[{i}]['constraint_fn'] is not differentiable with JAX: {e}"
            ) from e

        # Validate nodes if provided
        if "nodes" in constraint_spec:
            nodes = constraint_spec["nodes"]
            if not isinstance(nodes, (list, tuple)):
                raise TypeError(
                    f"byof nodal_constraints[{i}]['nodes'] must be a list or tuple, "
                    f"got {type(nodes)}"
                )
            if len(nodes) == 0:
                raise ValueError(f"byof nodal_constraints[{i}]['nodes'] cannot be empty")

            for node in nodes:
                # Node entries must be integers; a float such as 1.5 would
                # otherwise slip through the range check below. operator.index
                # accepts Python ints and NumPy integer scalars but rejects
                # floats and strings.
                try:
                    operator.index(node)
                except TypeError:
                    raise TypeError(
                        f"byof nodal_constraints[{i}]['nodes'] entries must be integers, "
                        f"got {node!r} ({type(node)})"
                    ) from None

                # Validate node indices if N is provided
                if N is not None:
                    # Handle negative indices (e.g., -1 for last node)
                    normalized_node = node if node >= 0 else N + node
                    # Validate range
                    if not (0 <= normalized_node < N):
                        raise ValueError(
                            f"byof nodal_constraints[{i}]['nodes'] contains invalid index {node} "
                            f"(normalized: {normalized_node}). Valid range is [0, {N}) or "
                            f"negative indices [-{N}, -1]."
                        )

    # Validate cross-nodal constraints on a dummy trajectory. When the real
    # node count is known, use it: then functions that depend on the
    # trajectory length (e.g. indexing a specific node or reshaping by N) see
    # the shape they will actually receive. Otherwise fall back to 10 nodes.
    n_dummy_nodes = N if N is not None and N > 0 else 10
    dummy_X = jnp.zeros((n_dummy_nodes, n_x))
    dummy_U = jnp.zeros((n_dummy_nodes, n_u))

    for i, fn in enumerate(byof.get("cross_nodal_constraints", [])):
        if not callable(fn):
            raise TypeError(f"byof cross_nodal_constraints[{i}] must be callable, got {type(fn)}")

        # Check signature
        sig = inspect.signature(fn)
        if len(sig.parameters) != 3:
            raise ValueError(
                f"byof cross_nodal_constraints[{i}] must have signature f(X, U, params), "
                f"got {len(sig.parameters)} parameters: {list(sig.parameters.keys())}"
            )

        # Test call
        try:
            result = fn(dummy_X, dummy_U, dummy_params)
        except Exception as e:
            raise ValueError(
                f"byof cross_nodal_constraints[{i}] failed on test call with "
                f"X.shape={dummy_X.shape}, U.shape={dummy_U.shape}: {e}"
            ) from e

        # Check that result is array-like
        try:
            jnp.asarray(result)
        except Exception as e:
            raise ValueError(
                f"byof cross_nodal_constraints[{i}] must return array-like value, "
                f"got {type(result)}: {e}"
            ) from e

        # Test gradient
        try:
            jax.grad(lambda X: jnp.sum(fn(X, dummy_U, dummy_params)))(dummy_X)
        except Exception as e:
            raise ValueError(
                f"byof cross_nodal_constraints[{i}] is not differentiable with JAX: {e}"
            ) from e

    # Validate CTCS constraints
    for i, ctcs_spec in enumerate(byof.get("ctcs_constraints", [])):
        if not isinstance(ctcs_spec, dict):
            raise TypeError(f"byof ctcs_constraints[{i}] must be a dict, got {type(ctcs_spec)}")

        if "constraint_fn" not in ctcs_spec:
            raise ValueError(f"byof ctcs_constraints[{i}] missing required key 'constraint_fn'")

        fn = ctcs_spec["constraint_fn"]
        if not callable(fn):
            raise TypeError(
                f"byof ctcs_constraints[{i}]['constraint_fn'] must be callable, got {type(fn)}"
            )

        # Check signature
        sig = inspect.signature(fn)
        if len(sig.parameters) != 4:
            raise ValueError(
                f"byof ctcs_constraints[{i}]['constraint_fn'] must have signature "
                f"f(x, u, node, params), got {len(sig.parameters)} parameters: "
                f"{list(sig.parameters.keys())}"
            )

        # Test call
        try:
            result = fn(dummy_x, dummy_u, dummy_node, dummy_params)
        except Exception as e:
            raise ValueError(
                f"byof ctcs_constraints[{i}]['constraint_fn'] failed on test call: {e}"
            ) from e

        # Check that result is scalar (jax.grad below requires a scalar output)
        result_array = jnp.asarray(result)
        if result_array.shape != ():
            raise ValueError(
                f"byof ctcs_constraints[{i}]['constraint_fn'] must return a scalar, "
                f"got shape {result_array.shape}"
            )

        # Test gradient
        try:
            jax.grad(lambda x: fn(x, dummy_u, dummy_node, dummy_params))(dummy_x)
        except Exception as e:
            raise ValueError(
                f"byof ctcs_constraints[{i}]['constraint_fn'] is not differentiable with JAX: {e}"
            ) from e

        # Validate penalty function if provided
        if "penalty" in ctcs_spec:
            penalty_spec = ctcs_spec["penalty"]
            if callable(penalty_spec):
                # Test custom penalty function on a sample scalar residual
                try:
                    test_residual = jnp.array(0.5)
                    penalty_result = penalty_spec(test_residual)
                    jnp.asarray(penalty_result)
                except Exception as e:
                    raise ValueError(
                        f"byof ctcs_constraints[{i}]['penalty'] custom function failed: {e}"
                    ) from e
            elif penalty_spec not in ["square", "l1", "huber"]:
                raise ValueError(
                    f"byof ctcs_constraints[{i}]['penalty'] must be 'square', 'l1', 'huber', "
                    f"or a callable, got {penalty_spec!r}"
                )

        # Validate idx if provided
        if "idx" in ctcs_spec:
            idx = ctcs_spec["idx"]
            if not isinstance(idx, int):
                raise TypeError(
                    f"byof ctcs_constraints[{i}]['idx'] must be an integer, got {type(idx)}"
                )
            if idx < 0:
                raise ValueError(
                    f"byof ctcs_constraints[{i}]['idx'] must be non-negative, got {idx}"
                )

        # Validate bounds if provided
        if "bounds" in ctcs_spec:
            bounds = ctcs_spec["bounds"]
            if not isinstance(bounds, (tuple, list)) or len(bounds) != 2:
                raise ValueError(
                    f"byof ctcs_constraints[{i}]['bounds'] must be a (min, max) tuple, got {bounds}"
                )
            if bounds[0] > bounds[1]:
                raise ValueError(
                    f"byof ctcs_constraints[{i}]['bounds'] min ({bounds[0]}) must be <= "
                    f"max ({bounds[1]})"
                )
        else:
            # Use default bounds for initial value validation
            bounds = (0.0, 1e-4)

        # Validate initial value is within bounds
        if "initial" in ctcs_spec:
            initial = ctcs_spec["initial"]
            if not (bounds[0] <= initial <= bounds[1]):
                raise ValueError(
                    f"byof ctcs_constraints[{i}]['initial'] ({initial}) must be within "
                    f"bounds [{bounds[0]}, {bounds[1]}]"
                )

        # Validate over (node interval) if provided
        if "over" in ctcs_spec:
            over = ctcs_spec["over"]
            if not isinstance(over, (tuple, list)) or len(over) != 2:
                raise ValueError(
                    f"byof ctcs_constraints[{i}]['over'] must be a (start, end) tuple, got {over}"
                )
            start, end = over
            if not isinstance(start, int) or not isinstance(end, int):
                raise TypeError(
                    f"byof ctcs_constraints[{i}]['over'] indices must be integers, "
                    f"got start={type(start)}, end={type(end)}"
                )
            if start >= end:
                raise ValueError(
                    f"byof ctcs_constraints[{i}]['over'] start ({start}) must be < end ({end})"
                )
            if start < 0:
                raise ValueError(
                    f"byof ctcs_constraints[{i}]['over'] start ({start}) must be non-negative"
                )
            # Validate against trajectory length if N is provided
            if N is not None:
                if end > N:
                    raise ValueError(
                        f"byof ctcs_constraints[{i}]['over'] end ({end}) exceeds "
                        f"trajectory length ({N})"
                    )
|
|
@@ -0,0 +1,48 @@
|
|
|
1
|
+
"""Numerical integration schemes for trajectory optimization.
|
|
2
|
+
|
|
3
|
+
This module provides implementations of numerical integrators used for simulating
|
|
4
|
+
continuous-time dynamics.
|
|
5
|
+
|
|
6
|
+
Current Implementations:
|
|
7
|
+
RK45 Integration: Explicit Runge-Kutta-Fehlberg method (4th/5th order)
|
|
8
|
+
with a hand-written fixed-step implementation (``solve_ivp_rk45``) and adaptive integration via Diffrax.
|
|
9
|
+
Supports a variety of explicit and implicit ODE solvers through the
|
|
10
|
+
Diffrax backend (Dopri5/8, Tsit5, KenCarp3/4/5, etc.).
|
|
11
|
+
|
|
12
|
+
Planned Architecture (ABC-based):
|
|
13
|
+
|
|
14
|
+
A base class will be introduced to enable pluggable integrator implementations.
|
|
15
|
+
This will enable users to implement custom integrators.
|
|
16
|
+
Future integrators will implement the Integrator interface:
|
|
17
|
+
|
|
18
|
+
```python
|
|
19
|
+
# integrators/base.py (planned):
|
|
20
|
+
class Integrator(ABC):
|
|
21
|
+
@abstractmethod
|
|
22
|
+
def step(self, f: Callable, x: Array, u: Array, t: float, dt: float) -> Array:
|
|
23
|
+
'''Take one integration step from state x at time t with step dt.'''
|
|
24
|
+
...
|
|
25
|
+
|
|
26
|
+
@abstractmethod
|
|
27
|
+
def integrate(self, f: Callable, x0: Array, u_traj: Array,
|
|
28
|
+
t_span: tuple[float, float], num_steps: int) -> Array:
|
|
29
|
+
'''Integrate over a time span with given control trajectory.'''
|
|
30
|
+
...
|
|
31
|
+
```
|
|
32
|
+
"""
|
|
33
|
+
|
|
34
|
+
from .runge_kutta import (
|
|
35
|
+
SOLVER_MAP,
|
|
36
|
+
rk45_step,
|
|
37
|
+
solve_ivp_diffrax,
|
|
38
|
+
solve_ivp_diffrax_prop,
|
|
39
|
+
solve_ivp_rk45,
|
|
40
|
+
)
|
|
41
|
+
|
|
42
|
+
# Names re-exported from ``runge_kutta`` as this package's public API.
__all__ = [
    "SOLVER_MAP",
    "rk45_step",
    "solve_ivp_rk45",
    "solve_ivp_diffrax",
    "solve_ivp_diffrax_prop",
]
|
|
@@ -0,0 +1,281 @@
|
|
|
1
|
+
import os
|
|
2
|
+
from typing import Any, Callable
|
|
3
|
+
|
|
4
|
+
import diffrax as dfx
|
|
5
|
+
import jax
|
|
6
|
+
import jax.numpy as jnp
|
|
7
|
+
from diffrax._global_interpolation import DenseInterpolation
|
|
8
|
+
from jax import tree_util
|
|
9
|
+
|
|
10
|
+
# Request "nan" as the on-error behaviour via the EQX_ON_ERROR environment
# variable (presumably read by equinox, which diffrax builds on). This mutates
# the process-wide environment at import time and overrides any value the
# user already set.
# NOTE(review): if equinox reads EQX_ON_ERROR only when it is first imported,
# this assignment has no effect. diffrax (and therefore equinox) is imported
# above. Confirm the intended ordering.
os.environ["EQX_ON_ERROR"] = "nan"
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
# Safely check if DenseInterpolation is already registered
|
|
14
|
+
# Register DenseInterpolation as a JAX pytree only if it is not already one.
# The probe builds a placeholder instance and tries to flatten it. A
# ValueError from the probe is taken to mean "not registered".
#
# NOTE(review): things to confirm against the pinned diffrax/JAX versions:
#   * DenseInterpolation lives in a private diffrax module and is presumably
#     already a pytree (an equinox Module). If so, the except branch is dead
#     code.
#   * As far as I know, tree_flatten treats unregistered types as leaves
#     instead of raising, so this probe may never detect a missing
#     registration.
#   * Only ValueError is caught. A TypeError from the placeholder constructor
#     (e.g. if diffrax changes this class's fields) would abort this module's
#     import.
#   * The fallback handlers look inconsistent. flatten reads `obj._data`,
#     which is not among the constructor fields used below. unflatten passes
#     that single child positionally to DenseInterpolation(...).
try:
    # Attempt to flatten a dummy DenseInterpolation instance
    # Provide dummy arguments to create a valid instance
    dummy_instance = DenseInterpolation(
        ts=jnp.array([]),
        ts_size=0,
        infos=None,
        interpolation_cls=None,
        direction=None,
        t0_if_trivial=0.0,
        y0_if_trivial=jnp.array([]),
    )
    tree_util.tree_flatten(dummy_instance)
except ValueError:
    # Register DenseInterpolation as a PyTree node if not already registered
    def dense_interpolation_flatten(obj):
        # Flatten the internal data of DenseInterpolation
        return (obj._data,), None

    def dense_interpolation_unflatten(aux_data, children):
        # Reconstruct DenseInterpolation from its flattened data
        return DenseInterpolation(*children)

    tree_util.register_pytree_node(
        DenseInterpolation,
        dense_interpolation_flatten,
        dense_interpolation_unflatten,
    )
|
|
42
|
+
|
|
43
|
+
# Diffrax solver classes selectable by name. ``solve_ivp_diffrax`` and
# ``solve_ivp_diffrax_prop`` look up their ``solver_name`` argument here.
# Every key is the exact attribute name of the corresponding class in ``diffrax``,
# so the mapping is built directly from those names (insertion order preserved).
SOLVER_MAP = {
    solver_key: getattr(dfx, solver_key)
    for solver_key in (
        "Tsit5",
        "Euler",
        "Heun",
        "Midpoint",
        "Ralston",
        "Dopri5",
        "Dopri8",
        "Bosh3",
        "ReversibleHeun",
        "ImplicitEuler",
        "KenCarp3",
        "KenCarp4",
        "KenCarp5",
    )
}
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
# fmt: off
def rk45_step(
    f: Callable[[jnp.ndarray, jnp.ndarray, Any], jnp.ndarray],
    t: jnp.ndarray,
    y: jnp.ndarray,
    h: float,
    *args
) -> jnp.ndarray:
    """
    Perform a single RK45 (Runge-Kutta-Fehlberg) integration step.

    Uses the Runge-Kutta-Fehlberg 4(5) tableau (stage times t, t + h/4,
    t + 3h/8, t + 12h/13, t + h) and returns the fourth-order estimate, which
    needs only these five stages. No embedded error estimate is computed and
    the step size is not adapted.

    Args:
        f (Callable[[jnp.ndarray, jnp.ndarray, Any], jnp.ndarray]):
            ODE right-hand side; signature f(t, y, *args) -> dy/dt.
        t (jnp.ndarray): Current time.
        y (jnp.ndarray): Current state vector.
        h (float): Step size.
        *args: Additional arguments passed to `f`.

    Returns:
        jnp.ndarray: Next state estimate at t + h.
    """
    k1 = f(t, y, *args)
    k2 = f(t + h/4, y + h*k1/4, *args)
    k3 = f(t + 3*h/8, y + 3*h*k1/32 + 9*h*k2/32, *args)
    k4 = f(t + 12*h/13, y + 1932*h*k1/2197 - 7200*h*k2/2197 + 7296*h*k3/2197, *args)
    k5 = f(t + h, y + 439*h*k1/216 - 8*h*k2 + 3680*h*k3/513 - 845*h*k4/4104, *args)
    y_next = y + h * (25*k1/216 + 1408*k3/2565 + 2197*k4/4104 - k5/5)
    return y_next
# fmt: on


def solve_ivp_rk45(
    f: Callable[[jnp.ndarray, jnp.ndarray, Any], jnp.ndarray],
    tau_final: float,
    y_0: jnp.ndarray,
    args,
    tau_0: float = 0.0,
    num_substeps: int = 50,
    is_not_compiled: bool = False,
):
    """
    Solve an initial-value ODE problem using fixed-step RK45 integration.

    The interval [tau_0, tau_final] is divided into ``num_substeps - 1`` equal
    steps of size ``h``. The solution is recorded at every grid point, including
    both endpoints. Step ``i`` advances from grid point ``i - 1`` (time
    ``tau_0 + (i - 1) * h``) to grid point ``i`` in both code paths.

    Args:
        f (Callable[[jnp.ndarray, jnp.ndarray, Any], jnp.ndarray]):
            ODE right-hand side; signature f(t, y, *args) -> dy/dt.
        tau_final (float): Final integration time.
        y_0 (jnp.ndarray): Initial state at tau_0.
        args (tuple): Extra arguments to pass to `f`.
        tau_0 (float, optional): Initial time. Defaults to 0.0.
        num_substeps (int, optional): Number of output time points, including
            tau_0 and tau_final. Defaults to 50.
        is_not_compiled (bool, optional): If True, use Python loop instead of
            JAX `lax.fori_loop`. Defaults to False.

    Returns:
        jnp.ndarray: Array of shape (num_substeps, state_dim). Row ``i`` holds the
        state at ``tau_0 + i * h``, where ``h = (tau_final - tau_0) / (num_substeps - 1)``.
    """
    substeps = jnp.linspace(tau_0, tau_final, num_substeps)

    h = (tau_final - tau_0) / (len(substeps) - 1)
    solution = jnp.zeros((len(substeps), len(y_0)))
    solution = solution.at[0].set(y_0)

    if is_not_compiled:
        for i in range(1, len(substeps)):
            # The step producing row i starts from row i - 1, whose time is
            # tau_0 + (i - 1) * h. This matches the fori_loop branch below,
            # which starts at t = tau_0.
            t = tau_0 + (i - 1) * h
            solution = solution.at[i].set(rk45_step(f, t, solution[i - 1], h, *args))
    else:

        def body_fun(i, val):
            # Carry: (start time of this step, current state, output buffer).
            t, y, V_result = val
            y_next = rk45_step(f, t, y, h, *args)
            V_result = V_result.at[i].set(y_next)
            return (t + h, y_next, V_result)

        _, _, solution = jax.lax.fori_loop(1, len(substeps), body_fun, (tau_0, y_0, solution))

    return solution
|
|
142
|
+
|
|
143
|
+
|
|
144
|
+
def solve_ivp_diffrax(
    f: Callable[[jnp.ndarray, jnp.ndarray, Any], jnp.ndarray],
    tau_final: float,
    y_0: jnp.ndarray,
    args,
    tau_0: float = 0.0,
    num_substeps: int = 50,
    solver_name: str = "Dopri8",
    rtol: float = 1e-3,
    atol: float = 1e-6,
    extra_kwargs=None,
):
    """
    Integrate an ODE with an adaptive Diffrax solver and sample it on a uniform grid.

    The solution is saved at ``num_substeps`` evenly spaced times from ``tau_0``
    to ``tau_final`` (both endpoints included). The spacing of that grid is also
    used as the solver's initial step-size guess. Step sizes are then controlled
    by a PID controller with the given tolerances.

    Args:
        f (Callable[[jnp.ndarray, jnp.ndarray, Any], jnp.ndarray]):
            ODE right-hand side, called as f(t, y, *args).
        tau_final (float): Final integration time.
        y_0 (jnp.ndarray): Initial state at tau_0.
        args (tuple): Extra positional arguments unpacked into `f`.
        tau_0 (float, optional): Initial time. Defaults to 0.0.
        num_substeps (int, optional): Number of save points, evenly spaced over
            [tau_0, tau_final] inclusive. Defaults to 50.
        solver_name (str, optional): Key into SOLVER_MAP selecting the Diffrax
            solver class. Defaults to "Dopri8".
        rtol (float, optional): Relative tolerance for adaptive stepping. Defaults to 1e-3.
        atol (float, optional): Absolute tolerance for adaptive stepping. Defaults to 1e-6.
        extra_kwargs (dict, optional): Additional keyword arguments forwarded to
            `diffeqsolve`. They must not repeat the keywords set here.

    Returns:
        jnp.ndarray: Solution states at the save points, shape (num_substeps, state_dim).

    Raises:
        ValueError: If `solver_name` is not in SOLVER_MAP.
    """
    save_grid = jnp.linspace(tau_0, tau_final, num_substeps)

    solver_cls = SOLVER_MAP.get(solver_name)
    if solver_cls is None:
        raise ValueError(f"Unknown solver: {solver_name}")
    solver = solver_cls()

    def vector_field(t, y, packed_args):
        # Diffrax hands ``args`` over as a single object; unpack it for ``f``.
        return f(t, y, *packed_args)

    ode_term = dfx.ODETerm(vector_field)
    controller = dfx.PIDController(rtol=rtol, atol=atol)
    initial_dt = (tau_final - tau_0) / (len(save_grid) - 1)
    forwarded = extra_kwargs or {}

    result = dfx.diffeqsolve(
        ode_term,
        solver=solver,
        t0=tau_0,
        t1=tau_final,
        dt0=initial_dt,
        y0=y_0,
        args=args,
        stepsize_controller=controller,
        saveat=dfx.SaveAt(ts=save_grid),
        progress_meter=dfx.NoProgressMeter(),
        **forwarded,
    )
    return result.ys
|
|
204
|
+
|
|
205
|
+
|
|
206
|
+
def solve_ivp_diffrax_prop(
    f: Callable[[jnp.ndarray, jnp.ndarray, Any], jnp.ndarray],
    tau_final: float,
    y_0: jnp.ndarray,
    args,
    tau_0: float = 0.0,
    num_substeps: int = 50,
    solver_name: str = "Dopri8",
    rtol: float = 1e-3,
    atol: float = 1e-6,
    extra_kwargs=None,
    save_time: jnp.ndarray = None,
    mask: jnp.ndarray = None,
):
    """
    Solve an initial-value ODE problem using a Diffrax adaptive solver.

    This function is specifically designed for use in the context of
    trajectory optimization. It handles the nonlinear single-shot propagation
    of state variables in undilated time.

    The solve keeps a dense interpolant, which is evaluated at every entry of
    ``save_time``. Rows whose ``mask`` entry is False are replaced with zeros
    rather than removed. The output shape therefore depends only on
    ``save_time``, not on how many entries are masked.

    Args:
        f (Callable[[jnp.ndarray, jnp.ndarray, Any], jnp.ndarray]): ODE right-hand side;
            signature f(t, y, *args) -> dy/dt.
        tau_final (float): Final integration time.
        y_0 (jnp.ndarray): Initial state at tau_0.
        args (tuple): Extra arguments to pass to `f` in the solver term.
        tau_0 (float, optional): Initial time. Defaults to 0.0.
        num_substeps (int, optional): Not used by this function. It is accepted so
            the signature matches ``solve_ivp_diffrax``; the save points come from
            ``save_time``. Defaults to 50.
        solver_name (str, optional): Key into SOLVER_MAP for the Diffrax solver class.
            Defaults to "Dopri8".
        rtol (float, optional): Relative tolerance for adaptive stepping. Defaults to 1e-3.
        atol (float, optional): Absolute tolerance for adaptive stepping. Defaults to 1e-6.
        extra_kwargs (dict, optional): Additional keyword arguments forwarded to `diffeqsolve`.
        save_time (jnp.ndarray): 1-D array of time points at which to evaluate the
            solution. Must be provided for export compatibility.
        mask (jnp.ndarray, optional): Boolean mask with the same length as
            ``save_time``; True keeps the corresponding evaluation. Defaults to
            all True.

    Returns:
        jnp.ndarray: Array of shape (len(save_time), state_dim) holding the solution
        at each ``save_time`` entry, with masked-out rows set to zero.

    Raises:
        ValueError: If `solver_name` is not in SOLVER_MAP or if save_time is not provided.
    """

    if save_time is None:
        raise ValueError("save_time must be provided for export compatibility.")
    if mask is None:
        mask = jnp.ones_like(save_time, dtype=bool)

    solver_class = SOLVER_MAP.get(solver_name)
    if solver_class is None:
        raise ValueError(f"Unknown solver: {solver_name}")
    solver = solver_class()

    term = dfx.ODETerm(lambda t, y, args: f(t, y, *args))
    stepsize_controller = dfx.PIDController(rtol=rtol, atol=atol)

    solution = dfx.diffeqsolve(
        term,
        solver=solver,
        t0=tau_0,
        t1=tau_final,
        # Initial step guess spanning the whole interval; the PID step-size
        # controller adapts it from there.
        dt0=(tau_final - tau_0) / 1,
        y0=y_0,
        args=args,
        stepsize_controller=stepsize_controller,
        # Keep a dense interpolant so the solution can be evaluated at the
        # arbitrary save_time points below.
        saveat=dfx.SaveAt(dense=True),
        **(extra_kwargs or {}),
    )

    # Evaluate all save_time points (static size), then zero out masked rows.
    all_evals = jax.vmap(solution.evaluate)(save_time)  # shape: (len(save_time), n_states)
    masked_array = jnp.where(mask[:, None], all_evals, jnp.zeros_like(all_evals))
    # Shape is still (len(save_time), n_states): masking zeroes rows instead of
    # dropping them, which keeps the output size static.

    return masked_array
|