openscvx 0.3.2.dev170__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of openscvx might be problematic. Click here for more details.

Files changed (79) hide show
  1. openscvx/__init__.py +123 -0
  2. openscvx/_version.py +34 -0
  3. openscvx/algorithms/__init__.py +92 -0
  4. openscvx/algorithms/autotuning.py +24 -0
  5. openscvx/algorithms/base.py +351 -0
  6. openscvx/algorithms/optimization_results.py +215 -0
  7. openscvx/algorithms/penalized_trust_region.py +384 -0
  8. openscvx/config.py +437 -0
  9. openscvx/discretization/__init__.py +47 -0
  10. openscvx/discretization/discretization.py +236 -0
  11. openscvx/expert/__init__.py +23 -0
  12. openscvx/expert/byof.py +326 -0
  13. openscvx/expert/lowering.py +419 -0
  14. openscvx/expert/validation.py +357 -0
  15. openscvx/integrators/__init__.py +48 -0
  16. openscvx/integrators/runge_kutta.py +281 -0
  17. openscvx/lowered/__init__.py +30 -0
  18. openscvx/lowered/cvxpy_constraints.py +23 -0
  19. openscvx/lowered/cvxpy_variables.py +124 -0
  20. openscvx/lowered/dynamics.py +34 -0
  21. openscvx/lowered/jax_constraints.py +133 -0
  22. openscvx/lowered/parameters.py +54 -0
  23. openscvx/lowered/problem.py +70 -0
  24. openscvx/lowered/unified.py +718 -0
  25. openscvx/plotting/__init__.py +63 -0
  26. openscvx/plotting/plotting.py +756 -0
  27. openscvx/plotting/scp_iteration.py +299 -0
  28. openscvx/plotting/viser/__init__.py +126 -0
  29. openscvx/plotting/viser/animated.py +605 -0
  30. openscvx/plotting/viser/plotly_integration.py +333 -0
  31. openscvx/plotting/viser/primitives.py +355 -0
  32. openscvx/plotting/viser/scp.py +459 -0
  33. openscvx/plotting/viser/server.py +112 -0
  34. openscvx/problem.py +734 -0
  35. openscvx/propagation/__init__.py +60 -0
  36. openscvx/propagation/post_processing.py +104 -0
  37. openscvx/propagation/propagation.py +248 -0
  38. openscvx/solvers/__init__.py +51 -0
  39. openscvx/solvers/cvxpy.py +226 -0
  40. openscvx/symbolic/__init__.py +9 -0
  41. openscvx/symbolic/augmentation.py +630 -0
  42. openscvx/symbolic/builder.py +492 -0
  43. openscvx/symbolic/constraint_set.py +92 -0
  44. openscvx/symbolic/expr/__init__.py +222 -0
  45. openscvx/symbolic/expr/arithmetic.py +517 -0
  46. openscvx/symbolic/expr/array.py +632 -0
  47. openscvx/symbolic/expr/constraint.py +796 -0
  48. openscvx/symbolic/expr/control.py +135 -0
  49. openscvx/symbolic/expr/expr.py +720 -0
  50. openscvx/symbolic/expr/lie/__init__.py +87 -0
  51. openscvx/symbolic/expr/lie/adjoint.py +357 -0
  52. openscvx/symbolic/expr/lie/se3.py +172 -0
  53. openscvx/symbolic/expr/lie/so3.py +138 -0
  54. openscvx/symbolic/expr/linalg.py +279 -0
  55. openscvx/symbolic/expr/math.py +699 -0
  56. openscvx/symbolic/expr/spatial.py +209 -0
  57. openscvx/symbolic/expr/state.py +607 -0
  58. openscvx/symbolic/expr/stl.py +136 -0
  59. openscvx/symbolic/expr/variable.py +321 -0
  60. openscvx/symbolic/hashing.py +112 -0
  61. openscvx/symbolic/lower.py +760 -0
  62. openscvx/symbolic/lowerers/__init__.py +106 -0
  63. openscvx/symbolic/lowerers/cvxpy.py +1302 -0
  64. openscvx/symbolic/lowerers/jax.py +1382 -0
  65. openscvx/symbolic/preprocessing.py +757 -0
  66. openscvx/symbolic/problem.py +110 -0
  67. openscvx/symbolic/time.py +116 -0
  68. openscvx/symbolic/unified.py +420 -0
  69. openscvx/utils/__init__.py +20 -0
  70. openscvx/utils/cache.py +131 -0
  71. openscvx/utils/caching.py +210 -0
  72. openscvx/utils/printing.py +301 -0
  73. openscvx/utils/profiling.py +37 -0
  74. openscvx/utils/utils.py +100 -0
  75. openscvx-0.3.2.dev170.dist-info/METADATA +350 -0
  76. openscvx-0.3.2.dev170.dist-info/RECORD +79 -0
  77. openscvx-0.3.2.dev170.dist-info/WHEEL +5 -0
  78. openscvx-0.3.2.dev170.dist-info/licenses/LICENSE +201 -0
  79. openscvx-0.3.2.dev170.dist-info/top_level.txt +1 -0
@@ -0,0 +1,236 @@
1
+ import jax.numpy as jnp
2
+ import numpy as np
3
+
4
+ from openscvx.config import Config
5
+ from openscvx.integrators import solve_ivp_diffrax, solve_ivp_rk45
6
+ from openscvx.lowered import Dynamics
7
+
8
+
9
def dVdt(
    tau: float,
    V: jnp.ndarray,
    u_cur: np.ndarray,
    u_next: np.ndarray,
    state_dot: callable,
    A: callable,
    B: callable,
    n_x: int,
    n_u: int,
    N: int,
    dis_type: str,
    params: dict,
) -> jnp.ndarray:
    """Compute the time derivative of the augmented state vector.

    The augmented vector ``V`` stacks, for every trajectory segment, the state
    (``n_x`` entries), the state transition matrix (``n_x * n_x`` entries) and
    two control influence matrices (``n_x * n_u`` entries each), the first
    weighted by ``1 - beta`` and the second by ``beta``.

    Args:
        tau (float): Current normalized time.
        V (jnp.ndarray): Flattened augmented state vector. It is reshaped
            internally to ``(-1, n_x + n_x * n_x + 2 * n_x * n_u)``.
        u_cur (np.ndarray): Control input at the current node of each segment.
            The last column is used as a multiplicative scale on the dynamics.
        u_next (np.ndarray): Control input at the next node of each segment.
        state_dot (callable): Function computing state derivatives, called as
            ``state_dot(x, u[:, :-1], nodes, params)``.
        A (callable): Function computing the state Jacobian (same arguments).
        B (callable): Function computing the control Jacobian (same arguments).
        n_x (int): Number of states.
        n_u (int): Number of controls (including the trailing scale column).
        N (int): Number of nodes in trajectory.
        dis_type (str): Discretization type ("ZOH" or "FOH").
        params (dict): Parameter dictionary forwarded unchanged to
            ``state_dot``, ``A`` and ``B``.

    Returns:
        jnp.ndarray: Time derivative of the augmented state vector, flattened
        to one dimension.

    Raises:
        ValueError: If ``dis_type`` is neither "ZOH" nor "FOH".
    """
    # One node index per trajectory segment.
    nodes = jnp.arange(0, N - 1)

    # Offsets of the four sub-blocks inside one row of the augmented state.
    i0 = 0
    i1 = n_x
    i2 = i1 + n_x * n_x
    i3 = i2 + n_x * n_u
    i4 = i3 + n_x * n_u

    # Unflatten V into one row per segment.
    V = V.reshape(-1, i4)

    # Interpolation factor between u_cur and u_next.
    if dis_type == "ZOH":
        beta = 0.0
    elif dis_type == "FOH":
        # NOTE(review): calculate_discretization passes 1.0 / (N - 1) to the
        # integrator. If that value is the final tau, beta reaches N / (N - 1)
        # rather than 1 at the end of a segment -- confirm whether (N - 1)
        # was intended here.
        beta = tau * N
    else:
        # Previously an unknown type fell through and failed later with an
        # UnboundLocalError on ``beta``; fail early with a clear message.
        raise ValueError(f'Unknown dis_type {dis_type!r}; expected "ZOH" or "FOH".')
    alpha = 1 - beta

    # Interpolate the control input; its last column scales the dynamics.
    u = u_cur + beta * (u_next - u_cur)
    s = u[:, -1]

    # Ensure the state batch and the control batch have the same size.
    x = V[:, :n_x]
    u = u[: x.shape[0]]

    # Nonlinear propagation term.
    f = state_dot(x, u[:, :-1], nodes, params)
    F = s[:, None] * f

    # Scaled state Jacobian.
    dfdx = A(x, u[:, :-1], nodes, params)
    sdfdx = s[:, None, None] * dfdx

    # Augmented control Jacobian: scaled Jacobian w.r.t. the leading controls,
    # plus the unscaled dynamics as the column belonging to the scale itself.
    dfdu_veh = B(x, u[:, :-1], nodes, params)
    dfdu = jnp.zeros((V.shape[0], n_x, n_u))
    dfdu = dfdu.at[:, :, :-1].set(s[:, None, None] * dfdu_veh)
    dfdu = dfdu.at[:, :, -1].set(f)

    # Current matrix-valued sub-blocks of the augmented state.
    stm = V[:, i1:i2].reshape(-1, n_x, n_x)
    ctrl_cur = V[:, i2:i3].reshape(-1, n_x, n_u)
    ctrl_next = V[:, i3:i4].reshape(-1, n_x, n_u)

    stm_dot = jnp.matmul(sdfdx, stm)
    ctrl_cur_dot = jnp.matmul(sdfdx, ctrl_cur) + dfdu * alpha
    ctrl_next_dot = jnp.matmul(sdfdx, ctrl_next) + dfdu * beta

    # Stack the results back into the augmented layout.
    dV = jnp.zeros_like(V)
    dV = dV.at[:, i0:i1].set(F)
    dV = dV.at[:, i1:i2].set(stm_dot.reshape(-1, n_x * n_x))
    dV = dV.at[:, i2:i3].set(ctrl_cur_dot.reshape(-1, n_x * n_u))
    dV = dV.at[:, i3:i4].set(ctrl_next_dot.reshape(-1, n_x * n_u))

    return dV.reshape(-1)
106
+
107
+
108
def calculate_discretization(
    x,
    u,
    state_dot: callable,
    A: callable,
    B: callable,
    settings: Config,
    params: dict,
):
    """Integrate the augmented dynamics and unpack the discretized matrices.

    Builds the initial augmented state for all ``N - 1`` segments (state taken
    from ``x[:-1]``, identity state transition matrix, zero control matrices),
    hands :func:`dVdt` to the configured integrator, and slices the last entry
    of the solution into its components.

    Args:
        x: State trajectory; ``x[:-1]`` seeds the state part of each segment.
        u: Control trajectory; ``u[:-1]`` and ``u[1:]`` are passed to ``dVdt``
            as the current/next controls of each segment.
        state_dot (callable): Function computing state derivatives.
        A (callable): Function computing state Jacobian.
        B (callable): Function computing control Jacobian.
        settings (Config): Configuration object. Reads ``sim.n_states``,
            ``sim.n_controls``, ``scp.n``, ``dis.dis_type`` and
            ``dis.custom_integrator``; additionally ``dev.debug`` for the
            custom RK45 integrator, or ``dis.solver``, ``dis.rtol``,
            ``dis.atol`` and ``dis.args`` for the diffrax integrator.
        params (dict): Parameter dictionary forwarded to ``state_dot``, ``A``
            and ``B`` as a single dict.

    Returns:
        tuple: ``(A_bar, B_bar, C_bar, x_prop, Vmulti)`` where:
            - A_bar: Discretized state transition matrices, ``(N-1, n_x, n_x)``
            - B_bar: Discretized control influence matrices, ``(N-1, n_x, n_u)``
            - C_bar: Discretized control influence matrices for the next
              node, ``(N-1, n_x, n_u)``
            - x_prop: State part of the final augmented state per segment
            - Vmulti: Transpose of the full integrator solution
    """
    # Problem dimensions.
    n_x = settings.sim.n_states
    n_u = settings.sim.n_controls
    N = settings.scp.n
    n_seg = N - 1

    # Offsets of the sub-blocks inside one row of the augmented state.
    stm_size = n_x * n_x
    ctrl_size = n_x * n_u
    i0, i1 = 0, n_x
    i2 = i1 + stm_size
    i3 = i2 + ctrl_size
    i4 = i3 + ctrl_size

    # Initial augmented state: [x_k, vec(I), 0, 0] for every segment.
    eye_rows = jnp.eye(n_x).reshape(1, -1).repeat(n_seg, axis=0)
    V0 = jnp.zeros((n_seg, i4))
    V0 = V0.at[:, i0:i1].set(x[:-1].astype(float))
    V0 = V0.at[:, i1:i2].set(eye_rows)

    # Everything dVdt needs besides (t, y), evaluated once up front.
    u_cur = u[:-1].astype(float)
    u_next = u[1:].astype(float)
    dis_type = settings.dis.dis_type

    def dVdt_wrapped(t, y):
        return dVdt(
            t,
            y,
            u_cur=u_cur,
            u_next=u_next,
            state_dot=state_dot,
            A=A,
            B=B,
            n_x=n_x,
            n_u=n_u,
            N=N,
            dis_type=dis_type,
            params=params,
        )

    # Second positional integrator argument (presumably the final normalized
    # time of one segment -- confirm against openscvx.integrators).
    horizon = 1.0 / n_seg

    # Choose integrator.
    if settings.dis.custom_integrator:
        sol = solve_ivp_rk45(
            dVdt_wrapped,
            horizon,
            V0.reshape(-1),
            args=(),
            is_not_compiled=settings.dev.debug,
        )
    else:
        sol = solve_ivp_diffrax(
            dVdt_wrapped,
            horizon,
            V0.reshape(-1),
            solver_name=settings.dis.solver,
            rtol=settings.dis.rtol,
            atol=settings.dis.atol,
            args=(),
            extra_kwargs=settings.dis.args,
        )

    # Last solution entry, one row per segment.
    Vend = sol[-1].T.reshape(-1, i4)
    Vmulti = sol.T

    x_prop = Vend[:, i0:i1]
    A_bar = Vend[:, i1:i2].reshape(n_seg, n_x, n_x)
    B_bar = Vend[:, i2:i3].reshape(n_seg, n_x, n_u)
    C_bar = Vend[:, i3:i4].reshape(n_seg, n_x, n_u)

    return A_bar, B_bar, C_bar, x_prop, Vmulti
213
+
214
+
215
def get_discretization_solver(dyn: Dynamics, settings: Config):
    """Build a discretization function bound to the given dynamics and settings.

    Args:
        dyn (Dynamics): System dynamics object; its ``f``, ``A`` and ``B``
            attributes supply the state derivative and Jacobian functions.
        settings (Config): Configuration settings forwarded to
            :func:`calculate_discretization`.

    Returns:
        callable: A function ``(x, u, params)`` that returns the result of
        :func:`calculate_discretization` for the bound dynamics and settings.
    """

    def discretization_solver(x, u, params):
        # dyn.f / dyn.A / dyn.B are looked up on every call, not captured once.
        return calculate_discretization(
            x=x,
            u=u,
            state_dot=dyn.f,
            A=dyn.A,
            B=dyn.B,
            settings=settings,
            params=params,  # Passed through as a single dict
        )

    return discretization_solver
@@ -0,0 +1,23 @@
1
+ """Expert-mode features for advanced users.
2
+
3
+ This module contains features for expert users who need fine-grained control
4
+ and are willing to bypass higher-level abstractions.
5
+ """
6
+
7
+ from openscvx.expert.byof import (
8
+ ByofSpec,
9
+ CtcsConstraintSpec,
10
+ NodalConstraintSpec,
11
+ PenaltyFunction,
12
+ )
13
+ from openscvx.expert.lowering import apply_byof
14
+ from openscvx.expert.validation import validate_byof
15
+
16
+ __all__ = [
17
+ "ByofSpec",
18
+ "CtcsConstraintSpec",
19
+ "NodalConstraintSpec",
20
+ "PenaltyFunction",
21
+ "apply_byof",
22
+ "validate_byof",
23
+ ]
@@ -0,0 +1,326 @@
1
+ """Bring-Your-Own-Functions (BYOF) - Expert User Mode.
2
+
3
+ This module provides type definitions and documentation for expert users who want
4
+ to bypass the symbolic layer and directly provide raw JAX functions.
5
+
6
+ Important:
7
+ The unified state/control vectors include ALL states/controls in the order
8
+ they were provided, plus any augmented states from CTCS constraints. You are
9
+ responsible for correct indexing. Consider inspecting the symbolic problem
10
+ to understand the layout.
11
+
12
+ Warning:
13
+ **Constraint Sign Convention**: All constraints follow g(x,u) <= 0 convention.
14
+ Return **negative when satisfied**, **positive when violated**.
15
+ Example: for x <= 10 return ``x - 10``, for x >= 5 return ``5 - x``.
16
+
17
+ Function Signatures:
18
+ All byof functions must be JAX-compatible (use jax.numpy, avoid side effects).
19
+
20
+ - dynamics: ``(x, u, node, params) -> xdot_component``
21
+ - x: Full unified state vector (1D array)
22
+ - u: Full unified control vector (1D array)
23
+ - node: Integer node index
24
+ - params: Dict of parameters
25
+ - Returns: State derivative component (array matching state shape)
26
+
27
+ - nodal_constraints: ``(x, u, node, params) -> residual``
28
+ - Same arguments as dynamics
29
+ - Returns: Constraint residual (g <= 0: negative=satisfied, positive=violated)
30
+
31
+ - cross_nodal_constraints: ``(X, U, params) -> residual``
32
+ - X: State trajectory (N, n_x) where N is number of trajectory nodes,
33
+ n_x is unified state dimension
34
+ - U: Control trajectory (N, n_u) where N is number of trajectory nodes,
35
+ n_u is unified control dimension
36
+ - params: Dict of parameters
37
+ - Returns: Constraint residual (g <= 0: negative=satisfied, positive=violated)
38
+
39
+ - ctcs constraint_fn: ``(x, u, node, params) -> scalar``
40
+ - Same as nodal_constraints but MUST return scalar
41
+ - Returns: Constraint residual (g <= 0: negative=satisfied, positive=violated)
42
+
43
+ - ctcs penalty: ``(residual) -> penalty_value``
44
+ - residual: Scalar constraint residual
45
+ - Returns: Non-negative penalty value
46
+
47
+ Example:
48
+ Basic usage mixing symbolic and byof::
49
+
50
+ import jax.numpy as jnp
51
+ import openscvx as ox
52
+ from openscvx import ByofSpec
53
+
54
+ # Define states
55
+ position = ox.State("position", shape=(2,))
56
+ velocity = ox.State("velocity", shape=(1,))
57
+ theta = ox.Control("theta", shape=(1,))
58
+
59
+ # Unified state: [position[0], position[1], velocity[0], time, augmented...]
60
+ # Unified control: [theta[0], time_dilation]
61
+
62
+ # Tip: Use the .slice property on State/Control objects for cleaner,
63
+ # more maintainable indexing instead of hardcoded indices.
64
+ byof: ByofSpec = {
65
+ "nodal_constraints": [
66
+ # Velocity bounds (applied to all nodes)
67
+ {
68
+ "constraint_fn": lambda x, u, node, params: x[velocity.slice][0] - 10.0,
69
+ },
70
+ {
71
+ "constraint_fn": lambda x, u, node, params: -x[velocity.slice][0],
72
+ },
73
+ # Velocity must be exactly 0 at start (selective enforcement)
74
+ {
75
+ "constraint_fn": lambda x, u, node, params: x[velocity.slice][0],
76
+ "nodes": [0], # Only at first node
77
+ },
78
+ ],
79
+ "ctcs_constraints": [
80
+ {
81
+ "constraint_fn": lambda x, u, node, params: x[position.slice][0] - 10.0,
82
+ "penalty": "square",
83
+ "bounds": (0.0, 1e-4),
84
+ }
85
+ ],
86
+ }
87
+
88
+ problem = ox.Problem(..., byof=byof)
89
+ """
90
+
91
+ from typing import TYPE_CHECKING, Any, Callable, List, Literal, Tuple, TypedDict, Union
92
+
93
+ if TYPE_CHECKING:
94
+ from jax import Array as JaxArray
95
+ else:
96
+ JaxArray = Any
97
+
98
+ __all__ = ["ByofSpec", "CtcsConstraintSpec", "NodalConstraintSpec", "PenaltyFunction"]
99
+
100
+
101
# Type aliases for clarity
# Dynamics callback: (x, u, node, params) -> xdot_component.
DynamicsFunction = Callable[[JaxArray, JaxArray, int, dict], JaxArray]
# Nodal constraint callback: (x, u, node, params) -> residual (g <= 0 convention).
NodalConstraintFunction = Callable[[JaxArray, JaxArray, int, dict], JaxArray]
# Cross-nodal constraint callback: (X, U, params) -> residual over whole trajectories.
CrossNodalConstraintFunction = Callable[[JaxArray, JaxArray, dict], JaxArray]
# CTCS constraint callback: same arguments as a nodal constraint, scalar residual.
CtcsConstraintFunction = Callable[[JaxArray, JaxArray, int, dict], float]
# Either the name of a built-in penalty or a custom (residual) -> penalty callable.
PenaltyFunction = Union[Literal["square", "l1", "huber"], Callable[[float], float]]
107
+
108
+
109
class _NodalConstraintRequired(TypedDict):
    """Keys that every :class:`NodalConstraintSpec` must provide."""

    constraint_fn: NodalConstraintFunction


class NodalConstraintSpec(_NodalConstraintRequired, total=False):
    """Specification for nodal constraint with optional node selection.

    Nodal constraints are point-wise constraints evaluated at specific trajectory nodes.
    By default, constraints apply to all nodes, but you can restrict enforcement to
    specific nodes for boundary conditions, waypoints, or computational efficiency.

    ``constraint_fn`` is inherited from a total base TypedDict so that static type
    checkers treat it as required, while ``nodes`` stays optional. At runtime the
    spec is still a plain ``dict``.

    Attributes:
        constraint_fn: Constraint function with signature ``(x, u, node, params) -> residual``.
            Follows g(x,u) <= 0 convention (negative = satisfied). Required field.
        nodes: List of integer node indices where constraint is enforced.
            If omitted, applies to all nodes. Negative indices supported (e.g., -1 for last).
            Optional field.

    Example:
        Boundary constraint only at first and last nodes::

            nodal_spec: NodalConstraintSpec = {
                "constraint_fn": lambda x, u, node, params: x[velocity.slice][0],
                "nodes": [0, -1],  # Only at start and end
            }

        Waypoint constraint at middle of trajectory::

            nodal_spec: NodalConstraintSpec = {
                "constraint_fn": lambda x, u, node, params: jnp.linalg.norm(
                    x[position.slice] - jnp.array([5.0, 7.5])
                ) - 0.1,
                "nodes": [N // 2],
            }
    """

    nodes: List[int]
143
+
144
+
145
class _CtcsConstraintRequired(TypedDict):
    """Keys that every :class:`CtcsConstraintSpec` must provide."""

    constraint_fn: CtcsConstraintFunction


class CtcsConstraintSpec(_CtcsConstraintRequired, total=False):
    """Specification for CTCS (Continuous-Time Constraint Satisfaction) constraint.

    CTCS constraints are enforced by augmenting the dynamics with a penalty term that
    accumulates violations over time. Useful for path constraints that must be satisfied
    continuously, not just at discrete nodes.

    ``constraint_fn`` is inherited from a total base TypedDict so that static type
    checkers treat it as required; every other key is optional. At runtime the
    spec is still a plain ``dict``.

    Attributes:
        constraint_fn: Function computing constraint residual with signature
            ``(x, u, node, params) -> scalar``. Must return scalar.
            Follows g(x,u) <= 0 convention (negative = satisfied). Required field.
        penalty: Penalty function for positive residuals (violations).
            Built-in options: "square" (max(r,0)^2, default), "l1" (max(r,0)),
            "huber" (Huber loss). Custom: Callable ``(r) -> penalty`` (non-negative,
            differentiable).
        bounds: (min, max) bounds for augmented state accumulating penalties.
            Default: (0.0, 1e-4). Max acts as soft constraint on total violation.
        initial: Initial value for augmented state. Default: bounds[0] (usually 0.0).
        over: Node interval (start, end) where constraint is active. The constraint
            is enforced for nodes in [start, end). If omitted, constraint is active
            over all nodes. Matches symbolic `.over()` method behavior.
        idx: Constraint group index for sharing augmented states (default: 0).
            All CTCS constraints (symbolic and byof) with the same idx share a single
            augmented state. Their penalties are summed together. Use different idx values
            to track different types of violations separately.

    Warning:
        If symbolic CTCS constraints exist with idx values [0, 1, 2], then byof idx **must**
        either:

        - Match an existing idx (e.g., 0, 1, or 2) to add to that augmented state
        - Be sequential after them (e.g., 3, 4, 5) to create new augmented states

        You cannot use idx values that create gaps (e.g., if symbolic has [0, 1],
        you cannot use byof idx=3 without also using idx=2).

    Example:
        Enforce position[0] <= 10.0 continuously::

            # Assuming position = ox.State("position", shape=(2,))
            ctcs_spec: CtcsConstraintSpec = {
                "constraint_fn": lambda x, u, node, params: x[position.slice][0] - 10.0,
                "penalty": "square",
                "bounds": (0.0, 1e-4),
                "initial": 0.0,
                "idx": 0,  # Groups with other constraints having idx=0
            }

        Enforce constraint only over specific node range::

            ctcs_spec: CtcsConstraintSpec = {
                "constraint_fn": lambda x, u, node, params: x[position.slice][0] - 10.0,
                "over": (10, 50),  # Active only for nodes 10-49
                "penalty": "square",
            }

        Multiple constraints sharing an augmented state::

            # If symbolic CTCS already has idx=[0, 1], then:

            byof = {
                "ctcs_constraints": [
                    # Add to existing symbolic idx=0 augmented state
                    {
                        "constraint_fn": lambda x, u, node, params: x[pos.slice][0] - 10.0,
                        "idx": 0,  # Shares with symbolic idx=0
                    },
                    # Add to existing symbolic idx=1 augmented state
                    {
                        "constraint_fn": lambda x, u, node, params: x[vel.slice][0] - 5.0,
                        "idx": 1,  # Shares with symbolic idx=1
                    },
                    # Create NEW augmented state (sequential after symbolic)
                    {
                        "constraint_fn": lambda x, u, node, params: x[pos.slice][1] - 8.0,
                        "idx": 2,  # New state (symbolic has 0,1, so next is 2)
                    },
                ]
            }
    """

    penalty: PenaltyFunction
    bounds: Tuple[float, float]
    initial: float
    over: Tuple[int, int]
    idx: int
231
+
232
+
233
class ByofSpec(TypedDict, total=False):
    """Bring-Your-Own-Functions specification for expert users.

    Allows bypassing the symbolic layer and directly providing raw JAX functions.
    All fields are optional - you can mix symbolic and byof as needed.

    Warning:
        You are responsible for:

        - Correct indexing into unified state/control vectors
        - Ensuring functions are JAX-compatible (use jax.numpy, no side effects)
        - Ensuring functions are differentiable
        - Following g(x,u) <= 0 convention for constraints

    Tip:
        Use the ``.slice`` property on State/Control objects for cleaner, more
        maintainable indexing instead of hardcoded indices. For example, use
        ``x[velocity.slice]`` instead of ``x[2:3]``. The slice property is set
        after preprocessing and provides the correct indices into the unified
        state/control vectors.

    Attributes:
        dynamics: Raw JAX functions for state derivatives. Maps state names to functions
            with signature ``(x, u, node, params) -> xdot_component``. States here should
            NOT appear in symbolic dynamics dict. You can mix: some states symbolic,
            some in byof.
        nodal_constraints: Point-wise constraints applied at specific nodes.
            Each item is a :class:`NodalConstraintSpec` dict with:

            - ``constraint_fn``: Constraint function ``(x, u, node, params) -> residual``
              (required)
            - ``nodes``: List of node indices (optional, defaults to all nodes)

            Follows g(x,u) <= 0 convention.
        cross_nodal_constraints: Constraints coupling multiple nodes (smoothness, rate limits).
            Signature: ``(X, U, params) -> residual`` where X is (N, n_x) and U is (N, n_u).
            N is the number of trajectory nodes, n_x is state dimension, n_u is control
            dimension. Follows g(X,U) <= 0 convention.
        ctcs_constraints: Continuous-time constraint satisfaction via dynamics augmentation.
            Violation penalties accumulate in augmented states; constraints with the same
            ``idx`` share one augmented state. See :class:`CtcsConstraintSpec` for details.

    Example:
        Custom dynamics and constraints::

            import jax.numpy as jnp
            import openscvx as ox
            from openscvx import ByofSpec

            # Define states and controls
            position = ox.State("position", shape=(2,))
            velocity = ox.State("velocity", shape=(1,))
            theta = ox.Control("theta", shape=(1,))

            # Custom dynamics for one state using .slice property
            def custom_velocity_dynamics(x, u, node, params):
                # Use .slice property for clean indexing
                return params["g"] * jnp.cos(u[theta.slice][0])

            byof: ByofSpec = {
                "dynamics": {
                    "velocity": custom_velocity_dynamics,
                },
                "nodal_constraints": [
                    # Applied to all nodes (no "nodes" field)
                    {
                        "constraint_fn": lambda x, u, node, params: x[velocity.slice][0] - 10.0,
                    },
                    {
                        "constraint_fn": lambda x, u, node, params: -x[velocity.slice][0],
                    },
                    # Specify nodes for selective enforcement
                    {
                        "constraint_fn": lambda x, u, node, params: x[velocity.slice][0],
                        "nodes": [0],  # Velocity must be exactly 0 at start
                    },
                ],
                "cross_nodal_constraints": [
                    # Constrain total velocity across trajectory: sum(velocities) >= 5
                    # X.shape = (N, n_x), extract velocity column using slice
                    lambda X, U, params: 5.0 - jnp.sum(X[:, velocity.slice]),
                ],
                "ctcs_constraints": [
                    {
                        "constraint_fn": lambda x, u, node, params: x[position.slice][0] - 5.0,
                        "penalty": "square",
                    }
                ],
            }
    """

    dynamics: dict[str, DynamicsFunction]
    nodal_constraints: List[NodalConstraintSpec]
    cross_nodal_constraints: List[CrossNodalConstraintFunction]
    ctcs_constraints: List[CtcsConstraintSpec]