openscvx-0.3.2.dev170-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

Files changed (79)
  1. openscvx/__init__.py +123 -0
  2. openscvx/_version.py +34 -0
  3. openscvx/algorithms/__init__.py +92 -0
  4. openscvx/algorithms/autotuning.py +24 -0
  5. openscvx/algorithms/base.py +351 -0
  6. openscvx/algorithms/optimization_results.py +215 -0
  7. openscvx/algorithms/penalized_trust_region.py +384 -0
  8. openscvx/config.py +437 -0
  9. openscvx/discretization/__init__.py +47 -0
  10. openscvx/discretization/discretization.py +236 -0
  11. openscvx/expert/__init__.py +23 -0
  12. openscvx/expert/byof.py +326 -0
  13. openscvx/expert/lowering.py +419 -0
  14. openscvx/expert/validation.py +357 -0
  15. openscvx/integrators/__init__.py +48 -0
  16. openscvx/integrators/runge_kutta.py +281 -0
  17. openscvx/lowered/__init__.py +30 -0
  18. openscvx/lowered/cvxpy_constraints.py +23 -0
  19. openscvx/lowered/cvxpy_variables.py +124 -0
  20. openscvx/lowered/dynamics.py +34 -0
  21. openscvx/lowered/jax_constraints.py +133 -0
  22. openscvx/lowered/parameters.py +54 -0
  23. openscvx/lowered/problem.py +70 -0
  24. openscvx/lowered/unified.py +718 -0
  25. openscvx/plotting/__init__.py +63 -0
  26. openscvx/plotting/plotting.py +756 -0
  27. openscvx/plotting/scp_iteration.py +299 -0
  28. openscvx/plotting/viser/__init__.py +126 -0
  29. openscvx/plotting/viser/animated.py +605 -0
  30. openscvx/plotting/viser/plotly_integration.py +333 -0
  31. openscvx/plotting/viser/primitives.py +355 -0
  32. openscvx/plotting/viser/scp.py +459 -0
  33. openscvx/plotting/viser/server.py +112 -0
  34. openscvx/problem.py +734 -0
  35. openscvx/propagation/__init__.py +60 -0
  36. openscvx/propagation/post_processing.py +104 -0
  37. openscvx/propagation/propagation.py +248 -0
  38. openscvx/solvers/__init__.py +51 -0
  39. openscvx/solvers/cvxpy.py +226 -0
  40. openscvx/symbolic/__init__.py +9 -0
  41. openscvx/symbolic/augmentation.py +630 -0
  42. openscvx/symbolic/builder.py +492 -0
  43. openscvx/symbolic/constraint_set.py +92 -0
  44. openscvx/symbolic/expr/__init__.py +222 -0
  45. openscvx/symbolic/expr/arithmetic.py +517 -0
  46. openscvx/symbolic/expr/array.py +632 -0
  47. openscvx/symbolic/expr/constraint.py +796 -0
  48. openscvx/symbolic/expr/control.py +135 -0
  49. openscvx/symbolic/expr/expr.py +720 -0
  50. openscvx/symbolic/expr/lie/__init__.py +87 -0
  51. openscvx/symbolic/expr/lie/adjoint.py +357 -0
  52. openscvx/symbolic/expr/lie/se3.py +172 -0
  53. openscvx/symbolic/expr/lie/so3.py +138 -0
  54. openscvx/symbolic/expr/linalg.py +279 -0
  55. openscvx/symbolic/expr/math.py +699 -0
  56. openscvx/symbolic/expr/spatial.py +209 -0
  57. openscvx/symbolic/expr/state.py +607 -0
  58. openscvx/symbolic/expr/stl.py +136 -0
  59. openscvx/symbolic/expr/variable.py +321 -0
  60. openscvx/symbolic/hashing.py +112 -0
  61. openscvx/symbolic/lower.py +760 -0
  62. openscvx/symbolic/lowerers/__init__.py +106 -0
  63. openscvx/symbolic/lowerers/cvxpy.py +1302 -0
  64. openscvx/symbolic/lowerers/jax.py +1382 -0
  65. openscvx/symbolic/preprocessing.py +757 -0
  66. openscvx/symbolic/problem.py +110 -0
  67. openscvx/symbolic/time.py +116 -0
  68. openscvx/symbolic/unified.py +420 -0
  69. openscvx/utils/__init__.py +20 -0
  70. openscvx/utils/cache.py +131 -0
  71. openscvx/utils/caching.py +210 -0
  72. openscvx/utils/printing.py +301 -0
  73. openscvx/utils/profiling.py +37 -0
  74. openscvx/utils/utils.py +100 -0
  75. openscvx-0.3.2.dev170.dist-info/METADATA +350 -0
  76. openscvx-0.3.2.dev170.dist-info/RECORD +79 -0
  77. openscvx-0.3.2.dev170.dist-info/WHEEL +5 -0
  78. openscvx-0.3.2.dev170.dist-info/licenses/LICENSE +201 -0
  79. openscvx-0.3.2.dev170.dist-info/top_level.txt +1 -0
@@ -0,0 +1,357 @@
+ """Validation for bring-your-own-functions (byof).
+
+ This module provides validation for user-provided JAX functions in expert mode,
+ checking signatures, shapes, and differentiability before use.
+ """
+
+ import inspect
+ from typing import TYPE_CHECKING, List
+
+ if TYPE_CHECKING:
+     from openscvx.symbolic.expr.state import State
+
+ __all__ = ["validate_byof"]
+
+
+ def validate_byof(
+     byof: dict,
+     states: List["State"],
+     n_x: int,
+     n_u: int,
+     N: int = None,
+ ) -> None:
+     """Validate byof function signatures and shapes.
+
+     Checks that user-provided functions have the correct signatures and return
+     appropriate shapes. Performs validation before functions are used to provide
+     clear error messages.
+
+     Args:
+         byof: Dictionary of user-provided functions to validate
+         states: List of State objects for determining expected shapes
+         n_x: Total dimension of the unified state vector
+         n_u: Total dimension of the unified control vector
+         N: Number of nodes in the trajectory (optional). If provided, validates
+             node indices in nodal constraints.
+
+     Raises:
+         ValueError: If any function has invalid signature or returns wrong shape
+         TypeError: If functions are not callable
+
+     Example:
+         >>> validate_byof(byof, states, n_x=10, n_u=3, N=50)  # Raises if invalid
+     """
+     import jax
+     import jax.numpy as jnp
+
+     # Validate byof keys
+     valid_keys = {"dynamics", "nodal_constraints", "cross_nodal_constraints", "ctcs_constraints"}
+     invalid_keys = set(byof.keys()) - valid_keys
+     if invalid_keys:
+         raise ValueError(f"Unknown byof keys: {invalid_keys}. Valid keys: {valid_keys}")
+
+     # Create dummy inputs for testing
+     dummy_x = jnp.zeros(n_x)
+     dummy_u = jnp.zeros(n_u)
+     dummy_node = 0
+     dummy_params = {}
+
+     # Validate dynamics functions
+     byof_dynamics = byof.get("dynamics", {})
+     if byof_dynamics:
+         # Build mapping from state name to expected shape
+         state_shapes = {state.name: state.shape for state in states}
+
+         for state_name, fn in byof_dynamics.items():
+             if state_name not in state_shapes:
+                 raise ValueError(
+                     f"byof dynamics '{state_name}' does not match any state name. "
+                     f"Available states: {list(state_shapes.keys())}"
+                 )
+
+             if not callable(fn):
+                 raise TypeError(f"byof dynamics '{state_name}' must be callable, got {type(fn)}")
+
+             # Check signature
+             sig = inspect.signature(fn)
+             if len(sig.parameters) != 4:
+                 raise ValueError(
+                     f"byof dynamics '{state_name}' must have signature f(x, u, node, params), "
+                     f"got {len(sig.parameters)} parameters: {list(sig.parameters.keys())}"
+                 )
+
+             # Test call and check output shape
+             try:
+                 result = fn(dummy_x, dummy_u, dummy_node, dummy_params)
+             except Exception as e:
+                 raise ValueError(
+                     f"byof dynamics '{state_name}' failed on test call with "
+                     f"x.shape={dummy_x.shape}, u.shape={dummy_u.shape}: {e}"
+                 ) from e
+
+             expected_shape = state_shapes[state_name]
+             result_shape = jnp.asarray(result).shape
+             if result_shape != expected_shape:
+                 raise ValueError(
+                     f"byof dynamics '{state_name}' returned shape {result_shape}, "
+                     f"expected {expected_shape} (state '{state_name}' shape)"
+                 )
+
+             # Test that gradient works (JAX compatibility check)
+             try:
+                 jax.grad(lambda x: jnp.sum(fn(x, dummy_u, dummy_node, dummy_params)))(dummy_x)
+             except Exception as e:
+                 raise ValueError(
+                     f"byof dynamics '{state_name}' is not differentiable with JAX. "
+                     f"Ensure the function uses JAX operations (jax.numpy, not numpy): {e}"
+                 ) from e
+
+     # Validate nodal constraints
+     for i, constraint_spec in enumerate(byof.get("nodal_constraints", [])):
+         if not isinstance(constraint_spec, dict):
+             raise TypeError(
+                 f"byof nodal_constraints[{i}] must be a dict (NodalConstraintSpec), "
+                 f"got {type(constraint_spec)}"
+             )
+
+         if "constraint_fn" not in constraint_spec:
+             raise ValueError(f"byof nodal_constraints[{i}] missing required key 'constraint_fn'")
+
+         fn = constraint_spec["constraint_fn"]
+         if not callable(fn):
+             raise TypeError(
+                 f"byof nodal_constraints[{i}]['constraint_fn'] must be callable, got {type(fn)}"
+             )
+
+         # Check signature
+         sig = inspect.signature(fn)
+         if len(sig.parameters) != 4:
+             raise ValueError(
+                 f"byof nodal_constraints[{i}]['constraint_fn'] must have signature "
+                 f"f(x, u, node, params), "
+                 f"got {len(sig.parameters)} parameters: {list(sig.parameters.keys())}"
+             )
+
+         # Test call
+         try:
+             result = fn(dummy_x, dummy_u, dummy_node, dummy_params)
+         except Exception as e:
+             raise ValueError(
+                 f"byof nodal_constraints[{i}]['constraint_fn'] failed on test call with "
+                 f"x.shape={dummy_x.shape}, u.shape={dummy_u.shape}: {e}"
+             ) from e
+
+         # Check that result is array-like (can be scalar or vector)
+         try:
+             result_array = jnp.asarray(result)
+         except Exception as e:
+             raise ValueError(
+                 f"byof nodal_constraints[{i}]['constraint_fn'] must return array-like value, "
+                 f"got {type(result)}: {e}"
+             ) from e
+
+         # Test gradient
+         try:
+             jax.grad(lambda x: jnp.sum(fn(x, dummy_u, dummy_node, dummy_params)))(dummy_x)
+         except Exception as e:
+             raise ValueError(
+                 f"byof nodal_constraints[{i}]['constraint_fn'] is not differentiable with JAX: {e}"
+             ) from e
+
+         # Validate nodes if provided
+         if "nodes" in constraint_spec:
+             nodes = constraint_spec["nodes"]
+             if not isinstance(nodes, (list, tuple)):
+                 raise TypeError(
+                     f"byof nodal_constraints[{i}]['nodes'] must be a list or tuple, "
+                     f"got {type(nodes)}"
+                 )
+             if len(nodes) == 0:
+                 raise ValueError(f"byof nodal_constraints[{i}]['nodes'] cannot be empty")
+
+             # Validate node indices if N is provided
+             if N is not None:
+                 for node in nodes:
+                     # Handle negative indices (e.g., -1 for last node)
+                     normalized_node = node if node >= 0 else N + node
+                     # Validate range
+                     if not (0 <= normalized_node < N):
+                         raise ValueError(
+                             f"byof nodal_constraints[{i}]['nodes'] contains invalid index {node} "
+                             f"(normalized: {normalized_node}). Valid range is [0, {N}) or "
+                             f"negative indices [-{N}, -1]."
+                         )
+
+     # Validate cross-nodal constraints
+     dummy_X = jnp.zeros((10, n_x))  # Dummy trajectory with 10 nodes
+     dummy_U = jnp.zeros((10, n_u))
+
+     for i, fn in enumerate(byof.get("cross_nodal_constraints", [])):
+         if not callable(fn):
+             raise TypeError(f"byof cross_nodal_constraints[{i}] must be callable, got {type(fn)}")
+
+         # Check signature
+         sig = inspect.signature(fn)
+         if len(sig.parameters) != 3:
+             raise ValueError(
+                 f"byof cross_nodal_constraints[{i}] must have signature f(X, U, params), "
+                 f"got {len(sig.parameters)} parameters: {list(sig.parameters.keys())}"
+             )
+
+         # Test call
+         try:
+             result = fn(dummy_X, dummy_U, dummy_params)
+         except Exception as e:
+             raise ValueError(
+                 f"byof cross_nodal_constraints[{i}] failed on test call with "
+                 f"X.shape={dummy_X.shape}, U.shape={dummy_U.shape}: {e}"
+             ) from e
+
+         # Check that result is array-like
+         try:
+             result_array = jnp.asarray(result)
+         except Exception as e:
+             raise ValueError(
+                 f"byof cross_nodal_constraints[{i}] must return array-like value, "
+                 f"got {type(result)}: {e}"
+             ) from e
+
+         # Test gradient
+         try:
+             jax.grad(lambda X: jnp.sum(fn(X, dummy_U, dummy_params)))(dummy_X)
+         except Exception as e:
+             raise ValueError(
+                 f"byof cross_nodal_constraints[{i}] is not differentiable with JAX: {e}"
+             ) from e
+
+     # Validate CTCS constraints
+     for i, ctcs_spec in enumerate(byof.get("ctcs_constraints", [])):
+         if not isinstance(ctcs_spec, dict):
+             raise TypeError(f"byof ctcs_constraints[{i}] must be a dict, got {type(ctcs_spec)}")
+
+         if "constraint_fn" not in ctcs_spec:
+             raise ValueError(f"byof ctcs_constraints[{i}] missing required key 'constraint_fn'")
+
+         fn = ctcs_spec["constraint_fn"]
+         if not callable(fn):
+             raise TypeError(
+                 f"byof ctcs_constraints[{i}]['constraint_fn'] must be callable, got {type(fn)}"
+             )
+
+         # Check signature
+         sig = inspect.signature(fn)
+         if len(sig.parameters) != 4:
+             raise ValueError(
+                 f"byof ctcs_constraints[{i}]['constraint_fn'] must have signature "
+                 f"f(x, u, node, params), got {len(sig.parameters)} parameters: "
+                 f"{list(sig.parameters.keys())}"
+             )
+
+         # Test call
+         try:
+             result = fn(dummy_x, dummy_u, dummy_node, dummy_params)
+         except Exception as e:
+             raise ValueError(
+                 f"byof ctcs_constraints[{i}]['constraint_fn'] failed on test call: {e}"
+             ) from e
+
+         # Check that result is scalar
+         result_array = jnp.asarray(result)
+         if result_array.shape != ():
+             raise ValueError(
+                 f"byof ctcs_constraints[{i}]['constraint_fn'] must return a scalar, "
+                 f"got shape {result_array.shape}"
+             )
+
+         # Test gradient
+         try:
+             jax.grad(lambda x: fn(x, dummy_u, dummy_node, dummy_params))(dummy_x)
+         except Exception as e:
+             raise ValueError(
+                 f"byof ctcs_constraints[{i}]['constraint_fn'] is not differentiable with JAX: {e}"
+             ) from e
+
+         # Validate penalty function if provided
+         if "penalty" in ctcs_spec:
+             penalty_spec = ctcs_spec["penalty"]
+             if callable(penalty_spec):
+                 # Test custom penalty function
+                 try:
+                     test_residual = jnp.array(0.5)
+                     penalty_result = penalty_spec(test_residual)
+                     jnp.asarray(penalty_result)
+                 except Exception as e:
+                     raise ValueError(
+                         f"byof ctcs_constraints[{i}]['penalty'] custom function failed: {e}"
+                     ) from e
+             elif penalty_spec not in ["square", "l1", "huber"]:
+                 raise ValueError(
+                     f"byof ctcs_constraints[{i}]['penalty'] must be 'square', 'l1', 'huber', "
+                     f"or a callable, got {penalty_spec!r}"
+                 )
+
+         # Validate idx if provided
+         if "idx" in ctcs_spec:
+             idx = ctcs_spec["idx"]
+             if not isinstance(idx, int):
+                 raise TypeError(
+                     f"byof ctcs_constraints[{i}]['idx'] must be an integer, got {type(idx)}"
+                 )
+             if idx < 0:
+                 raise ValueError(
+                     f"byof ctcs_constraints[{i}]['idx'] must be non-negative, got {idx}"
+                 )
+
+         # Validate bounds if provided
+         if "bounds" in ctcs_spec:
+             bounds = ctcs_spec["bounds"]
+             if not isinstance(bounds, (tuple, list)) or len(bounds) != 2:
+                 raise ValueError(
+                     f"byof ctcs_constraints[{i}]['bounds'] must be a (min, max) tuple, got {bounds}"
+                 )
+             if bounds[0] > bounds[1]:
+                 raise ValueError(
+                     f"byof ctcs_constraints[{i}]['bounds'] min ({bounds[0]}) must be <= "
+                     f"max ({bounds[1]})"
+                 )
+         else:
+             # Use default bounds for initial value validation
+             bounds = (0.0, 1e-4)
+
+         # Validate initial value is within bounds
+         if "initial" in ctcs_spec:
+             initial = ctcs_spec["initial"]
+             if not (bounds[0] <= initial <= bounds[1]):
+                 raise ValueError(
+                     f"byof ctcs_constraints[{i}]['initial'] ({initial}) must be within "
+                     f"bounds [{bounds[0]}, {bounds[1]}]"
+                 )
+
+         # Validate over (node interval) if provided
+         if "over" in ctcs_spec:
+             over = ctcs_spec["over"]
+             if not isinstance(over, (tuple, list)) or len(over) != 2:
+                 raise ValueError(
+                     f"byof ctcs_constraints[{i}]['over'] must be a (start, end) tuple, got {over}"
+                 )
+             start, end = over
+             if not isinstance(start, int) or not isinstance(end, int):
+                 raise TypeError(
+                     f"byof ctcs_constraints[{i}]['over'] indices must be integers, "
+                     f"got start={type(start)}, end={type(end)}"
+                 )
+             if start >= end:
+                 raise ValueError(
+                     f"byof ctcs_constraints[{i}]['over'] start ({start}) must be < end ({end})"
+                 )
+             if start < 0:
+                 raise ValueError(
+                     f"byof ctcs_constraints[{i}]['over'] start ({start}) must be non-negative"
+                 )
+             # Validate against trajectory length if N is provided
+             if N is not None:
+                 if end > N:
+                     raise ValueError(
+                         f"byof ctcs_constraints[{i}]['over'] end ({end}) exceeds "
+                         f"trajectory length ({N})"
+                     )
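
For reference, here is a minimal sketch of a `byof` dictionary that would pass these checks. The `FakeState` namedtuple is a hypothetical stand-in (the validator only reads `.name` and `.shape`); in real usage you would pass openscvx `State` objects, and the dynamics and constraint functions shown are illustrative only.

```python
# Hypothetical sketch of a byof dict accepted by validate_byof; shapes and
# functions are illustrative, not taken from an openscvx example problem.
from collections import namedtuple

import jax.numpy as jnp

from openscvx.expert.validation import validate_byof

# Stand-in for openscvx's State: the validator only reads .name and .shape
FakeState = namedtuple("FakeState", ["name", "shape"])
states = [FakeState(name="r", shape=(3,))]
n_x, n_u, N = 3, 2, 50

byof = {
    # Dynamics keyed by state name; signature f(x, u, node, params) -> state shape
    "dynamics": {
        "r": lambda x, u, node, params: jnp.concatenate([x[1:], jnp.array([u[0]])]),
    },
    # Nodal constraints g(x, u, node, params), applied at the listed nodes
    "nodal_constraints": [
        {
            "constraint_fn": lambda x, u, node, params: jnp.sum(x**2) - 1.0,
            "nodes": [0, -1],  # negative indices are normalized against N
        }
    ],
    # Cross-nodal constraints h(X, U, params) over the full trajectory
    "cross_nodal_constraints": [
        lambda X, U, params: jnp.sum(U**2) - 10.0,
    ],
    # CTCS constraints: scalar residual, penalized over a node interval
    "ctcs_constraints": [
        {
            "constraint_fn": lambda x, u, node, params: jnp.maximum(x[0] - 2.0, 0.0),
            "penalty": "square",
            "bounds": (0.0, 1e-4),
            "over": (0, N),
        }
    ],
}

# Raises ValueError/TypeError with a descriptive message if anything is malformed
validate_byof(byof, states, n_x=n_x, n_u=n_u, N=N)
```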
@@ -0,0 +1,48 @@
+ """Numerical integration schemes for trajectory optimization.
+
+ This module provides implementations of numerical integrators used for simulating
+ continuous-time dynamics.
+
+ Current Implementations:
+     RK45 Integration: Explicit Runge-Kutta-Fehlberg method (4th/5th order),
+         with a fixed-step implementation (solve_ivp_rk45) and adaptive
+         integration via Diffrax, which supports a variety of explicit and
+         implicit ODE solvers (Dopri5/8, Tsit5, KenCarp3/4/5, etc.).
+
+ Planned Architecture (ABC-based):
+
+     A base class will be introduced so that integrator implementations are
+     pluggable, letting users supply their own custom integrators.
+     Future integrators will implement the Integrator interface:
+
+     ```python
+     # integrators/base.py (planned):
+     class Integrator(ABC):
+         @abstractmethod
+         def step(self, f: Callable, x: Array, u: Array, t: float, dt: float) -> Array:
+             '''Take one integration step from state x at time t with step dt.'''
+             ...
+
+         @abstractmethod
+         def integrate(self, f: Callable, x0: Array, u_traj: Array,
+                       t_span: tuple[float, float], num_steps: int) -> Array:
+             '''Integrate over a time span with given control trajectory.'''
+             ...
+     ```
+ """
+
+ from .runge_kutta import (
+     SOLVER_MAP,
+     rk45_step,
+     solve_ivp_diffrax,
+     solve_ivp_diffrax_prop,
+     solve_ivp_rk45,
+ )
+
+ __all__ = [
+     "SOLVER_MAP",
+     "rk45_step",
+     "solve_ivp_rk45",
+     "solve_ivp_diffrax",
+     "solve_ivp_diffrax_prop",
+ ]
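
As a companion to the planned architecture sketched in the docstring above, the following is a hypothetical example of what a pluggable integrator could look like once the `Integrator` ABC exists. Neither class ships in this release; the base class is reproduced locally, and `EulerIntegrator` (including its assumed `f(t, x, u)` dynamics signature) is purely illustrative.

```python
# Hypothetical sketch only: the Integrator ABC is planned, not shipped, and
# EulerIntegrator is an illustrative implementation of that interface.
from abc import ABC, abstractmethod
from typing import Callable

import jax.numpy as jnp
from jax import Array


class Integrator(ABC):
    @abstractmethod
    def step(self, f: Callable, x: Array, u: Array, t: float, dt: float) -> Array:
        """Take one integration step from state x at time t with step dt."""

    @abstractmethod
    def integrate(self, f: Callable, x0: Array, u_traj: Array,
                  t_span: tuple[float, float], num_steps: int) -> Array:
        """Integrate over a time span with given control trajectory."""


class EulerIntegrator(Integrator):
    """Forward Euler: x_{k+1} = x_k + dt * f(t_k, x_k, u_k)."""

    def step(self, f: Callable, x: Array, u: Array, t: float, dt: float) -> Array:
        return x + dt * f(t, x, u)

    def integrate(self, f: Callable, x0: Array, u_traj: Array,
                  t_span: tuple[float, float], num_steps: int) -> Array:
        t0, t1 = t_span
        dt = (t1 - t0) / num_steps
        xs = [x0]
        for k in range(num_steps):
            # Zero-order hold: apply the control associated with the current step
            xs.append(self.step(f, xs[-1], u_traj[k], t0 + k * dt, dt))
        return jnp.stack(xs)


# Example: a lightly damped double integrator under zero control
def dynamics(t, x, u):
    return jnp.array([x[1], u[0] - 0.1 * x[1]])

traj = EulerIntegrator().integrate(
    dynamics, x0=jnp.zeros(2), u_traj=jnp.zeros((20, 1)),
    t_span=(0.0, 1.0), num_steps=20,
)
print(traj.shape)  # (21, 2)
```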
@@ -0,0 +1,281 @@
+ import os
+ from typing import Any, Callable
+
+ import diffrax as dfx
+ import jax
+ import jax.numpy as jnp
+ from diffrax._global_interpolation import DenseInterpolation
+ from jax import tree_util
+
+ os.environ["EQX_ON_ERROR"] = "nan"
+
+
+ # Safely check if DenseInterpolation is already registered
+ try:
+     # Attempt to flatten a dummy DenseInterpolation instance
+     # Provide dummy arguments to create a valid instance
+     dummy_instance = DenseInterpolation(
+         ts=jnp.array([]),
+         ts_size=0,
+         infos=None,
+         interpolation_cls=None,
+         direction=None,
+         t0_if_trivial=0.0,
+         y0_if_trivial=jnp.array([]),
+     )
+     tree_util.tree_flatten(dummy_instance)
+ except ValueError:
+     # Register DenseInterpolation as a PyTree node if not already registered
+     def dense_interpolation_flatten(obj):
+         # Flatten the internal data of DenseInterpolation
+         return (obj._data,), None
+
+     def dense_interpolation_unflatten(aux_data, children):
+         # Reconstruct DenseInterpolation from its flattened data
+         return DenseInterpolation(*children)
+
+     tree_util.register_pytree_node(
+         DenseInterpolation,
+         dense_interpolation_flatten,
+         dense_interpolation_unflatten,
+     )
+
+ SOLVER_MAP = {
+     "Tsit5": dfx.Tsit5,
+     "Euler": dfx.Euler,
+     "Heun": dfx.Heun,
+     "Midpoint": dfx.Midpoint,
+     "Ralston": dfx.Ralston,
+     "Dopri5": dfx.Dopri5,
+     "Dopri8": dfx.Dopri8,
+     "Bosh3": dfx.Bosh3,
+     "ReversibleHeun": dfx.ReversibleHeun,
+     "ImplicitEuler": dfx.ImplicitEuler,
+     "KenCarp3": dfx.KenCarp3,
+     "KenCarp4": dfx.KenCarp4,
+     "KenCarp5": dfx.KenCarp5,
+ }
+
+
+ # fmt: off
+ def rk45_step(
+     f: Callable[[jnp.ndarray, jnp.ndarray, Any], jnp.ndarray],
+     t: jnp.ndarray,
+     y: jnp.ndarray,
+     h: float,
+     *args
+ ) -> jnp.ndarray:
+     """
+     Perform a single RK45 (Runge-Kutta-Fehlberg) integration step.
+
+     This implements the classic Runge-Kutta-Fehlberg coefficients for an
+     explicit 4(5) method, returning the fourth-order estimate.
+
+     Args:
+         f (Callable[[jnp.ndarray, jnp.ndarray, Any], jnp.ndarray]):
+             ODE right-hand side; signature f(t, y, *args) -> dy/dt.
+         t (jnp.ndarray): Current time.
+         y (jnp.ndarray): Current state vector.
+         h (float): Step size.
+         *args: Additional arguments passed to `f`.
+
+     Returns:
+         jnp.ndarray: Next state estimate at t + h.
+     """
+     k1 = f(t, y, *args)
+     k2 = f(t + h/4, y + h*k1/4, *args)
+     k3 = f(t + 3*h/8, y + 3*h*k1/32 + 9*h*k2/32, *args)
+     k4 = f(t + 12*h/13, y + 1932*h*k1/2197 - 7200*h*k2/2197 + 7296*h*k3/2197, *args)
+     k5 = f(t + h, y + 439*h*k1/216 - 8*h*k2 + 3680*h*k3/513 - 845*h*k4/4104, *args)
+     y_next = y + h * (25*k1/216 + 1408*k3/2565 + 2197*k4/4104 - k5/5)
+     return y_next
+ # fmt: on
+
+
+ def solve_ivp_rk45(
+     f: Callable[[jnp.ndarray, jnp.ndarray, Any], jnp.ndarray],
+     tau_final: float,
+     y_0: jnp.ndarray,
+     args,
+     tau_0: float = 0.0,
+     num_substeps: int = 50,
+     is_not_compiled: bool = False,
+ ):
+     """
+     Solve an initial-value ODE problem using fixed-step RK45 integration.
+
+     Args:
+         f (Callable[[jnp.ndarray, jnp.ndarray, Any], jnp.ndarray]):
+             ODE right-hand side; signature f(t, y, *args) -> dy/dt.
+         tau_final (float): Final integration time.
+         y_0 (jnp.ndarray): Initial state at tau_0.
+         args (tuple): Extra arguments to pass to `f`.
+         tau_0 (float, optional): Initial time. Defaults to 0.0.
+         num_substeps (int, optional): Number of output time points. Defaults to 50.
+         is_not_compiled (bool, optional): If True, use a Python loop instead of
+             JAX `lax.fori_loop`. Defaults to False.
+
+     Returns:
+         jnp.ndarray: Array of shape (num_substeps, state_dim) with the solution at each time.
+     """
+     substeps = jnp.linspace(tau_0, tau_final, num_substeps)
+
+     h = (tau_final - tau_0) / (len(substeps) - 1)
+     solution = jnp.zeros((len(substeps), len(y_0)))
+     solution = solution.at[0].set(y_0)
+
+     if is_not_compiled:
+         for i in range(1, len(substeps)):
+             t = tau_0 + (i - 1) * h  # step from the previous substep, matching the compiled branch
+             solution = solution.at[i].set(rk45_step(f, t, solution[i - 1], h, *args))
+     else:
+
+         def body_fun(i, val):
+             t, y, V_result = val
+             y_next = rk45_step(f, t, y, h, *args)
+             V_result = V_result.at[i].set(y_next)
+             return (t + h, y_next, V_result)
+
+         _, _, solution = jax.lax.fori_loop(1, len(substeps), body_fun, (tau_0, y_0, solution))
+
+     return solution
+
+
+ def solve_ivp_diffrax(
+     f: Callable[[jnp.ndarray, jnp.ndarray, Any], jnp.ndarray],
+     tau_final: float,
+     y_0: jnp.ndarray,
+     args,
+     tau_0: float = 0.0,
+     num_substeps: int = 50,
+     solver_name: str = "Dopri8",
+     rtol: float = 1e-3,
+     atol: float = 1e-6,
+     extra_kwargs=None,
+ ):
+     """
+     Solve an initial-value ODE problem using a Diffrax adaptive solver.
+
+     Args:
+         f (Callable[[jnp.ndarray, jnp.ndarray, Any], jnp.ndarray]):
+             ODE right-hand side; f(t, y, *args).
+         tau_final (float): Final integration time.
+         y_0 (jnp.ndarray): Initial state at tau_0.
+         args (tuple): Extra arguments to pass to `f` in the solver term.
+         tau_0 (float, optional): Initial time. Defaults to 0.0.
+         num_substeps (int, optional): Number of save points between tau_0 and tau_final.
+             Defaults to 50.
+         solver_name (str, optional): Key into SOLVER_MAP for the Diffrax solver class.
+             Defaults to "Dopri8".
+         rtol (float, optional): Relative tolerance for adaptive stepping. Defaults to 1e-3.
+         atol (float, optional): Absolute tolerance for adaptive stepping. Defaults to 1e-6.
+         extra_kwargs (dict, optional): Additional keyword arguments forwarded to `diffeqsolve`.
+
+     Returns:
+         jnp.ndarray: Solution states at the requested save points, shape (num_substeps, state_dim).
+
+     Raises:
+         ValueError: If `solver_name` is not in SOLVER_MAP.
+     """
+     substeps = jnp.linspace(tau_0, tau_final, num_substeps)
+
+     solver_class = SOLVER_MAP.get(solver_name)
+     if solver_class is None:
+         raise ValueError(f"Unknown solver: {solver_name}")
+     solver = solver_class()
+
+     term = dfx.ODETerm(lambda t, y, args: f(t, y, *args))
+     stepsize_controller = dfx.PIDController(rtol=rtol, atol=atol)
+     solution = dfx.diffeqsolve(
+         term,
+         solver=solver,
+         t0=tau_0,
+         t1=tau_final,
+         dt0=(tau_final - tau_0) / (len(substeps) - 1),
+         y0=y_0,
+         args=args,
+         stepsize_controller=stepsize_controller,
+         saveat=dfx.SaveAt(ts=substeps),
+         progress_meter=dfx.NoProgressMeter(),
+         **(extra_kwargs or {}),
+     )
+
+     return solution.ys
+
+
+ def solve_ivp_diffrax_prop(
+     f: Callable[[jnp.ndarray, jnp.ndarray, Any], jnp.ndarray],
+     tau_final: float,
+     y_0: jnp.ndarray,
+     args,
+     tau_0: float = 0.0,
+     num_substeps: int = 50,
+     solver_name: str = "Dopri8",
+     rtol: float = 1e-3,
+     atol: float = 1e-6,
+     extra_kwargs=None,
+     save_time: jnp.ndarray = None,
+     mask: jnp.ndarray = None,
+ ):
+     """
+     Solve an initial-value ODE problem using a Diffrax adaptive solver.
+     This function is specifically designed for use in the context of
+     trajectory optimization and handles the nonlinear single-shot propagation
+     of state variables in undilated time.
+
+     Args:
+         f (Callable[[jnp.ndarray, jnp.ndarray, Any], jnp.ndarray]): ODE right-hand side;
+             signature f(t, y, *args) -> dy/dt.
+         tau_final (float): Final integration time.
+         y_0 (jnp.ndarray): Initial state at tau_0.
+         args (tuple): Extra arguments to pass to `f` in the solver term.
+         tau_0 (float, optional): Initial time. Defaults to 0.0.
+         num_substeps (int, optional): Number of save points between tau_0 and tau_final.
+             Defaults to 50.
+         solver_name (str, optional): Key into SOLVER_MAP for the Diffrax solver class.
+             Defaults to "Dopri8".
+         rtol (float, optional): Relative tolerance for adaptive stepping. Defaults to 1e-3.
+         atol (float, optional): Absolute tolerance for adaptive stepping. Defaults to 1e-6.
+         extra_kwargs (dict, optional): Additional keyword arguments forwarded to `diffeqsolve`.
+         save_time (jnp.ndarray, optional): Time points at which to evaluate the solution.
+             Must be provided for export compatibility.
+         mask (jnp.ndarray, optional): Boolean mask for the save_time points.
+
+     Returns:
+         jnp.ndarray: Solution states evaluated at save_time, shape (len(save_time), state_dim).
+
+     Raises:
+         ValueError: If `solver_name` is not in SOLVER_MAP or if save_time is not provided.
+     """
+
+     if save_time is None:
+         raise ValueError("save_time must be provided for export compatibility.")
+     if mask is None:
+         mask = jnp.ones_like(save_time, dtype=bool)
+
+     solver_class = SOLVER_MAP.get(solver_name)
+     if solver_class is None:
+         raise ValueError(f"Unknown solver: {solver_name}")
+     solver = solver_class()
+
+     term = dfx.ODETerm(lambda t, y, args: f(t, y, *args))
+     stepsize_controller = dfx.PIDController(rtol=rtol, atol=atol)
+
+     solution = dfx.diffeqsolve(
+         term,
+         solver=solver,
+         t0=tau_0,
+         t1=tau_final,
+         dt0=(tau_final - tau_0) / 1,
+         y0=y_0,
+         args=args,
+         stepsize_controller=stepsize_controller,
+         saveat=dfx.SaveAt(dense=True),
+         **(extra_kwargs or {}),
+     )
+
+     # Evaluate all save_time points (static size), then zero out masked entries
+     all_evals = jax.vmap(solution.evaluate)(save_time)  # shape: (len(save_time), n_states)
+     masked_array = jnp.where(mask[:, None], all_evals, jnp.zeros_like(all_evals))
+
+     return masked_array
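
Finally, a minimal usage sketch of the two solve paths above, assuming a scalar exponential-decay ODE; the import path and keyword names follow `openscvx/integrators/__init__.py` and the signatures in this file.

```python
# Usage sketch: integrate dy/dt = -a * y with the fixed-step and adaptive paths.
import jax.numpy as jnp

from openscvx.integrators import SOLVER_MAP, solve_ivp_diffrax, solve_ivp_rk45


def decay(t, y, a):
    # f(t, y, *args) -> dy/dt, with args = (a,)
    return -a * y


y0 = jnp.array([1.0])
args = (0.5,)

# Fixed-step RK45 (custom implementation; uses lax.fori_loop unless is_not_compiled=True)
ys_fixed = solve_ivp_rk45(decay, tau_final=10.0, y_0=y0, args=args, num_substeps=50)

# Adaptive stepping via Diffrax; solver_name must be one of SOLVER_MAP's keys
assert "Dopri8" in SOLVER_MAP
ys_adaptive = solve_ivp_diffrax(
    decay, tau_final=10.0, y_0=y0, args=args,
    num_substeps=50, solver_name="Dopri8", rtol=1e-6, atol=1e-9,
)

print(ys_fixed.shape, ys_adaptive.shape)  # (50, 1) (50, 1)
print(float(jnp.max(jnp.abs(ys_fixed - ys_adaptive))))  # small; fixed-step truncation error
```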