openscvx 0.3.2.dev170__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of openscvx might be problematic. Click here for more details.

Files changed (79) hide show
  1. openscvx/__init__.py +123 -0
  2. openscvx/_version.py +34 -0
  3. openscvx/algorithms/__init__.py +92 -0
  4. openscvx/algorithms/autotuning.py +24 -0
  5. openscvx/algorithms/base.py +351 -0
  6. openscvx/algorithms/optimization_results.py +215 -0
  7. openscvx/algorithms/penalized_trust_region.py +384 -0
  8. openscvx/config.py +437 -0
  9. openscvx/discretization/__init__.py +47 -0
  10. openscvx/discretization/discretization.py +236 -0
  11. openscvx/expert/__init__.py +23 -0
  12. openscvx/expert/byof.py +326 -0
  13. openscvx/expert/lowering.py +419 -0
  14. openscvx/expert/validation.py +357 -0
  15. openscvx/integrators/__init__.py +48 -0
  16. openscvx/integrators/runge_kutta.py +281 -0
  17. openscvx/lowered/__init__.py +30 -0
  18. openscvx/lowered/cvxpy_constraints.py +23 -0
  19. openscvx/lowered/cvxpy_variables.py +124 -0
  20. openscvx/lowered/dynamics.py +34 -0
  21. openscvx/lowered/jax_constraints.py +133 -0
  22. openscvx/lowered/parameters.py +54 -0
  23. openscvx/lowered/problem.py +70 -0
  24. openscvx/lowered/unified.py +718 -0
  25. openscvx/plotting/__init__.py +63 -0
  26. openscvx/plotting/plotting.py +756 -0
  27. openscvx/plotting/scp_iteration.py +299 -0
  28. openscvx/plotting/viser/__init__.py +126 -0
  29. openscvx/plotting/viser/animated.py +605 -0
  30. openscvx/plotting/viser/plotly_integration.py +333 -0
  31. openscvx/plotting/viser/primitives.py +355 -0
  32. openscvx/plotting/viser/scp.py +459 -0
  33. openscvx/plotting/viser/server.py +112 -0
  34. openscvx/problem.py +734 -0
  35. openscvx/propagation/__init__.py +60 -0
  36. openscvx/propagation/post_processing.py +104 -0
  37. openscvx/propagation/propagation.py +248 -0
  38. openscvx/solvers/__init__.py +51 -0
  39. openscvx/solvers/cvxpy.py +226 -0
  40. openscvx/symbolic/__init__.py +9 -0
  41. openscvx/symbolic/augmentation.py +630 -0
  42. openscvx/symbolic/builder.py +492 -0
  43. openscvx/symbolic/constraint_set.py +92 -0
  44. openscvx/symbolic/expr/__init__.py +222 -0
  45. openscvx/symbolic/expr/arithmetic.py +517 -0
  46. openscvx/symbolic/expr/array.py +632 -0
  47. openscvx/symbolic/expr/constraint.py +796 -0
  48. openscvx/symbolic/expr/control.py +135 -0
  49. openscvx/symbolic/expr/expr.py +720 -0
  50. openscvx/symbolic/expr/lie/__init__.py +87 -0
  51. openscvx/symbolic/expr/lie/adjoint.py +357 -0
  52. openscvx/symbolic/expr/lie/se3.py +172 -0
  53. openscvx/symbolic/expr/lie/so3.py +138 -0
  54. openscvx/symbolic/expr/linalg.py +279 -0
  55. openscvx/symbolic/expr/math.py +699 -0
  56. openscvx/symbolic/expr/spatial.py +209 -0
  57. openscvx/symbolic/expr/state.py +607 -0
  58. openscvx/symbolic/expr/stl.py +136 -0
  59. openscvx/symbolic/expr/variable.py +321 -0
  60. openscvx/symbolic/hashing.py +112 -0
  61. openscvx/symbolic/lower.py +760 -0
  62. openscvx/symbolic/lowerers/__init__.py +106 -0
  63. openscvx/symbolic/lowerers/cvxpy.py +1302 -0
  64. openscvx/symbolic/lowerers/jax.py +1382 -0
  65. openscvx/symbolic/preprocessing.py +757 -0
  66. openscvx/symbolic/problem.py +110 -0
  67. openscvx/symbolic/time.py +116 -0
  68. openscvx/symbolic/unified.py +420 -0
  69. openscvx/utils/__init__.py +20 -0
  70. openscvx/utils/cache.py +131 -0
  71. openscvx/utils/caching.py +210 -0
  72. openscvx/utils/printing.py +301 -0
  73. openscvx/utils/profiling.py +37 -0
  74. openscvx/utils/utils.py +100 -0
  75. openscvx-0.3.2.dev170.dist-info/METADATA +350 -0
  76. openscvx-0.3.2.dev170.dist-info/RECORD +79 -0
  77. openscvx-0.3.2.dev170.dist-info/WHEEL +5 -0
  78. openscvx-0.3.2.dev170.dist-info/licenses/LICENSE +201 -0
  79. openscvx-0.3.2.dev170.dist-info/top_level.txt +1 -0
@@ -0,0 +1,419 @@
1
+ """Lowering logic for bring-your-own-functions (byof).
2
+
3
+ This module handles integration of user-provided JAX functions into the
4
+ lowered problem representation, including dynamics splicing and constraint
5
+ addition.
6
+ """
7
+
8
+ from typing import TYPE_CHECKING, List, Tuple
9
+
10
+ import jax
11
+ import jax.numpy as jnp
12
+ import numpy as np
13
+ from jax import jacfwd
14
+ from jax.lax import cond
15
+
16
+ if TYPE_CHECKING:
17
+ from openscvx.lowered.unified import UnifiedState
18
+ from openscvx.symbolic.expr.state import State
19
+
20
+ from openscvx.lowered import (
21
+ Dynamics,
22
+ LoweredCrossNodeConstraint,
23
+ LoweredJaxConstraints,
24
+ LoweredNodalConstraint,
25
+ )
26
+
27
+ __all__ = ["apply_byof"]
28
+
29
+
30
def apply_byof(
    byof: dict,
    dynamics: "Dynamics",
    dynamics_prop: "Dynamics",
    jax_constraints: "LoweredJaxConstraints",
    x_unified: "UnifiedState",
    x_prop_unified: "UnifiedState",
    u_unified: "UnifiedState",
    states: List["State"],
    states_prop: List["State"],
    N: int,
) -> Tuple["Dynamics", "Dynamics", "LoweredJaxConstraints", "UnifiedState", "UnifiedState"]:
    """Apply bring-your-own-functions (byof) to augment a lowered problem.

    Expert users can supply raw JAX functions under these ``byof`` keys:

    - ``"dynamics"``: dict mapping a state name to a function
      ``(x, u, node, params) -> xdot_state`` that replaces the symbolic
      derivative of that state.
    - ``"nodal_constraints"``: list of specs with a ``"constraint_fn"`` and an
      optional ``"nodes"`` list (default: all nodes; negative indices count
      from the end).
    - ``"cross_nodal_constraints"``: list of functions coupling multiple nodes.
    - ``"ctcs_constraints"``: list of specs whose penalized violation is
      integrated into an augmented state (continuous-time constraint
      satisfaction via dynamics augmentation).

    Side effects: ``jax_constraints.nodal`` / ``jax_constraints.cross_node``,
    ``states``, ``states_prop``, ``x_unified`` and ``x_prop_unified`` are
    extended in place. The passed-in ``Dynamics`` objects are never mutated;
    any augmented dynamics are returned as new ``Dynamics`` objects.

    Args:
        byof: Dict with optional keys "dynamics", "nodal_constraints",
            "cross_nodal_constraints", "ctcs_constraints".
        dynamics: Lowered optimization dynamics to potentially augment.
        dynamics_prop: Lowered propagation dynamics to potentially augment.
        jax_constraints: Lowered JAX constraints to append to.
        x_unified: Unified optimization state interface to potentially augment.
        x_prop_unified: Unified propagation state interface to potentially augment.
        u_unified: Unified control interface; not read by this function
            (byof validation happens earlier, in ``Problem.__init__``).
        states: State objects for optimization (each with a ``_slice``).
        states_prop: State objects for propagation (each with a ``_slice``).
        N: Number of nodes in the trajectory.

    Returns:
        Tuple of (dynamics, dynamics_prop, jax_constraints, x_unified, x_prop_unified).

    Raises:
        ValueError: If the byof CTCS ``idx`` values, combined with the indices
            of existing symbolic CTCS states, are not a contiguous range from 0.
        KeyError: If a CTCS ``penalty`` is a name that is not a built-in penalty.

    Example:
        >>> dynamics, dynamics_prop, constraints, x_unified, x_prop_unified = apply_byof(
        ...     byof, dynamics, dynamics_prop, jax_constraints,
        ...     x_unified, x_prop_unified, u_unified, states, states_prop, N
        ... )
    """
    # Note: byof validation happens earlier in Problem.__init__ to fail fast.

    # 1) Splice raw JAX dynamics into the unified dynamics at each state's slice.
    byof_dynamics = byof.get("dynamics", {})
    if byof_dynamics:
        # Optimization and propagation vectors are laid out independently, so
        # each needs its own name -> slice map.
        state_slices = {state.name: state._slice for state in states}
        state_slices_prop = {state.name: state._slice for state in states_prop}
        dynamics = _make_dynamics(
            _make_composite_dynamics(dynamics.f, byof_dynamics, state_slices)
        )
        dynamics_prop = _make_dynamics(
            _make_composite_dynamics(dynamics_prop.f, byof_dynamics, state_slices_prop)
        )

    # 2) Point-wise (nodal) constraints.
    for constraint_spec in byof.get("nodal_constraints", []):
        jax_constraints.nodal.append(_lower_nodal_constraint(constraint_spec, N))

    # 3) Constraints coupling several nodes.
    for fn in byof.get("cross_nodal_constraints", []):
        jax_constraints.cross_node.append(
            LoweredCrossNodeConstraint(
                func=fn,
                grad_g_X=jacfwd(fn, argnums=0),
                grad_g_U=jacfwd(fn, argnums=1),
            )
        )

    # 4) CTCS constraints, realized by augmenting the dynamics.
    dynamics, dynamics_prop = _apply_byof_ctcs(
        byof.get("ctcs_constraints", []),
        dynamics,
        dynamics_prop,
        x_unified,
        x_prop_unified,
        states,
        states_prop,
        N,
    )

    return dynamics, dynamics_prop, jax_constraints, x_unified, x_prop_unified


# Name prefix of augmented CTCS states; the suffix after the last "_" is the idx.
_CTCS_AUG_PREFIX = "_ctcs_aug_"


def _make_dynamics(f) -> "Dynamics":
    """Bundle a dynamics function with its forward-mode Jacobians.

    Args:
        f: Dynamics function ``(x, u, node, params) -> xdot``.

    Returns:
        ``Dynamics`` with ``A = d f / d x`` (argnums=0) and
        ``B = d f / d u`` (argnums=1), both computed with ``jax.jacfwd``.
    """
    return Dynamics(f=f, A=jacfwd(f, argnums=0), B=jacfwd(f, argnums=1))


def _make_composite_dynamics(orig_f, byof_fns, slices_map):
    """Create dynamics that override selected state derivatives with byof functions.

    Args:
        orig_f: Original unified dynamics ``(x, u, node, params) -> xdot``.
        byof_fns: Dict mapping state names to byof derivative functions.
        slices_map: Dict mapping state names to slices of ``xdot``.

    Returns:
        Dynamics function equal to ``orig_f`` except that each byof state's
        slice is replaced by its byof function's output.
    """

    def composite_f(x, u, node, params):
        # Start from the symbolic/default derivatives for all states.
        xdot = orig_f(x, u, node, params)
        # Overwrite the derivative of each byof-provided state.
        for state_name, byof_fn in byof_fns.items():
            xdot = xdot.at[slices_map[state_name]].set(byof_fn(x, u, node, params))
        return xdot

    return composite_f


def _lower_nodal_constraint(constraint_spec: dict, N: int) -> "LoweredNodalConstraint":
    """Lower a byof nodal constraint spec, vectorized over nodes.

    Args:
        constraint_spec: Dict with ``"constraint_fn"`` ``(x, u, node, params)``
            and optional ``"nodes"`` (default: all ``N`` nodes).
        N: Number of nodes in the trajectory.

    Returns:
        ``LoweredNodalConstraint`` whose function and Jacobians are
        ``jax.vmap``-ed over the leading axis of ``x`` and ``u``.
    """
    fn = constraint_spec["constraint_fn"]
    nodes = constraint_spec.get("nodes", list(range(N)))  # Default: all nodes

    # Normalize negative node indices (range validation already done in validate_byof).
    normalized_nodes = [node if node >= 0 else N + node for node in nodes]

    # Batch over x and u; node and params are shared across the batch.
    in_axes = (0, 0, None, None)
    return LoweredNodalConstraint(
        func=jax.vmap(fn, in_axes=in_axes),
        grad_g_x=jax.vmap(jacfwd(fn, argnums=0), in_axes=in_axes),
        grad_g_u=jax.vmap(jacfwd(fn, argnums=1), in_axes=in_axes),
        nodes=normalized_nodes,
    )


def _penalty_square(r):
    """Squared positive part: ``max(r, 0) ** 2``."""
    return jnp.maximum(r, 0.0) ** 2


def _penalty_l1(r):
    """Positive part: ``max(r, 0)``."""
    return jnp.maximum(r, 0.0)


def _penalty_huber(r, delta=1.0):
    """Huber penalty of the positive part of ``r``.

    Quadratic (``0.5 * r**2``) up to ``delta``, then linear
    (``delta * (r - 0.5 * delta)``); zero where ``r <= 0``.
    """
    pos_r = jnp.maximum(r, 0.0)
    return jnp.where(pos_r <= delta, 0.5 * pos_r**2, delta * (pos_r - 0.5 * delta))


# Built-in penalties selectable by name via a CTCS spec's "penalty" key.
_PENALTY_FUNCTIONS = {
    "square": _penalty_square,
    "l1": _penalty_l1,
    "huber": _penalty_huber,
}


def _make_penalty_fn(cons_fn, pen_func, over):
    """Combine a constraint and a penalty into one node-gated penalty function.

    Args:
        cons_fn: Constraint function ``(x, u, node, params) -> residual``.
        pen_func: Penalty function ``residual -> penalty``.
        over: Optional ``(start, end)`` node interval; the penalty is active
            only for ``start <= node < end``. ``None`` means always active.

    Returns:
        Function ``(x, u, node, params) -> penalty``.
    """

    def penalty_fn(x, u, node, params):
        penalty_value = pen_func(cons_fn(x, u, node, params))
        if over is None:
            return penalty_value

        start_node, end_node = over
        # node may be a scalar or an array; keep it a JAX value for tracing.
        node_scalar = jnp.atleast_1d(node)[0]
        is_active = (start_node <= node_scalar) & (node_scalar < end_node)

        # lax.cond requires both branches to return identical types, so the
        # inactive branch returns zeros shaped/typed like the active penalty.
        return cond(
            is_active,
            lambda _: penalty_value,
            lambda _: jnp.zeros_like(penalty_value),
            operand=None,
        )

    return penalty_fn


def _build_penalty_fn(spec: dict):
    """Build the gated penalty function described by one byof CTCS spec.

    Raises:
        KeyError: If ``spec["penalty"]`` is a name not in ``_PENALTY_FUNCTIONS``.
    """
    constraint_fn = spec["constraint_fn"]
    penalty_spec = spec.get("penalty", "square")
    if callable(penalty_spec):
        penalty_func = penalty_spec
    else:
        try:
            penalty_func = _PENALTY_FUNCTIONS[penalty_spec]
        except KeyError:
            raise KeyError(
                f"Unknown CTCS penalty {penalty_spec!r}; expected a callable or "
                f"one of {sorted(_PENALTY_FUNCTIONS)}"
            ) from None
    return _make_penalty_fn(constraint_fn, penalty_func, spec.get("over", None))


def _make_ctcs_addition(orig_f, pen_fns, aug_sl):
    """Create dynamics that add summed penalties to an existing augmented state.

    Args:
        orig_f: Original dynamics function.
        pen_fns: Penalty functions ``(x, u, node, params) -> penalty`` to sum.
        aug_sl: Slice of ``xdot`` holding the augmented state derivative.

    Returns:
        Modified dynamics function.
    """

    def modified_f(x, u, node, params):
        xdot = orig_f(x, u, node, params)
        total_penalty = sum(pen_fn(x, u, node, params) for pen_fn in pen_fns)
        return xdot.at[aug_sl].set(xdot[aug_sl] + total_penalty)

    return modified_f


def _make_ctcs_new_state(orig_f, pen_fns):
    """Create dynamics with one extra trailing augmented state.

    Args:
        orig_f: Original dynamics function.
        pen_fns: Penalty functions whose sum is the new state's derivative.

    Returns:
        Augmented dynamics function whose output has the summed penalty
        appended to ``orig_f``'s output.
    """

    def augmented_f(x, u, node, params):
        xdot = orig_f(x, u, node, params)
        total_penalty = sum(pen_fn(x, u, node, params) for pen_fn in pen_fns)
        return jnp.concatenate([xdot, jnp.atleast_1d(total_penalty)])

    return augmented_f


def _collect_symbolic_ctcs(states):
    """Find augmented CTCS states already created by the symbolic pipeline.

    States are recognized by the ``_ctcs_aug_`` name prefix; the integer after
    the last underscore is taken as their idx. Names whose suffix does not
    parse as an integer are ignored.

    Returns:
        ``(count, idx_to_slice)``: how many such states were found, and a map
        from idx to ``_slice`` for those that have a ``_slice`` attribute.
    """
    count = 0
    idx_to_slice = {}
    for state in states:
        if not state.name.startswith(_CTCS_AUG_PREFIX):
            continue
        try:
            aug_idx = int(state.name.split("_")[-1])
        except (ValueError, IndexError):
            continue
        count += 1
        # Best effort: a state without a slice is counted but not mappable.
        try:
            idx_to_slice[aug_idx] = state._slice
        except AttributeError:
            pass
    return count, idx_to_slice


def _group_byof_ctcs(ctcs_specs) -> dict:
    """Group byof CTCS specs by their ``idx`` key (default 0), keeping order."""
    groups = {}
    for ctcs_spec in ctcs_specs:
        groups.setdefault(ctcs_spec.get("idx", 0), []).append(ctcs_spec)
    return groups


def _validate_ctcs_idx(num_symbolic: int, byof_groups: dict) -> None:
    """Ensure symbolic and byof CTCS idx values form ``[0, 1, ..., max_idx]``.

    Symbolic CTCS states are assumed to use idx ``0 .. num_symbolic - 1``.

    Raises:
        ValueError: If the combined idx set has gaps or does not start at 0.
    """
    symbolic_idx = range(num_symbolic)
    all_idx = sorted(set(symbolic_idx) | set(byof_groups))
    expected_idx = list(range(len(all_idx)))
    if all_idx != expected_idx:
        raise ValueError(
            f"BYOF CTCS idx values create non-contiguous sequence. "
            f"Symbolic CTCS has idx=[{', '.join(map(str, symbolic_idx))}], "
            f"combined with byof idx={sorted(byof_groups)} gives {all_idx}. "
            f"Expected contiguous sequence {expected_idx}. "
            f"Byof idx must either match existing symbolic idx or be sequential after them."
        )


def _make_aug_state(idx, bounds, initial, N: int, offset: int) -> "State":
    """Create a scalar augmented CTCS ``State`` placed at index ``offset``."""
    # Imported lazily, as in the original implementation (State is only
    # imported under TYPE_CHECKING at module level).
    from openscvx.symbolic.expr.state import State

    aug_state = State(f"{_CTCS_AUG_PREFIX}{idx}", shape=(1,))
    aug_state.min = np.array([bounds[0]])
    aug_state.max = np.array([bounds[1]])
    aug_state.initial = np.array([initial])
    aug_state.final = [("free", 0.0)]
    aug_state.guess = np.full((N, 1), initial)
    aug_state._slice = slice(offset, offset + 1)
    return aug_state


def _apply_byof_ctcs(
    ctcs_specs,
    dynamics: "Dynamics",
    dynamics_prop: "Dynamics",
    x_unified: "UnifiedState",
    x_prop_unified: "UnifiedState",
    states: List["State"],
    states_prop: List["State"],
    N: int,
) -> Tuple["Dynamics", "Dynamics"]:
    """Fold byof CTCS constraints into the optimization and propagation dynamics.

    Specs sharing an ``idx`` with an existing symbolic CTCS state add their
    penalties to that state's derivative. Any other ``idx`` gets a new trailing
    augmented state, which is appended to ``states``/``states_prop`` and to
    both unified state interfaces.

    Returns:
        ``(dynamics, dynamics_prop)``, rebuilt as new ``Dynamics`` objects
        when anything was added.

    Raises:
        ValueError: See ``_validate_ctcs_idx``.
    """
    if not ctcs_specs:
        return dynamics, dynamics_prop

    num_symbolic, idx_to_aug_slice = _collect_symbolic_ctcs(states)
    # The propagation vector may be laid out differently; look its slices up
    # separately rather than reusing the optimization slices.
    _, idx_to_aug_slice_prop = _collect_symbolic_ctcs(states_prop)

    groups = _group_byof_ctcs(ctcs_specs)
    _validate_ctcs_idx(num_symbolic, groups)

    # Sorted order processes existing idx first, then new ones in sequence, so
    # each new state is appended at the correct position.
    for idx in sorted(groups):
        specs = groups[idx]
        penalty_fns = [_build_penalty_fn(spec) for spec in specs]

        if idx in idx_to_aug_slice:
            # Existing symbolic CTCS state: add penalties to its derivative.
            aug_slice = idx_to_aug_slice[idx]
            aug_slice_prop = idx_to_aug_slice_prop.get(idx, aug_slice)
            dynamics = _make_dynamics(_make_ctcs_addition(dynamics.f, penalty_fns, aug_slice))
            dynamics_prop = _make_dynamics(
                _make_ctcs_addition(dynamics_prop.f, penalty_fns, aug_slice_prop)
            )
            continue

        # New idx: bounds/initial come from the first spec of this group.
        first_spec = specs[0]
        bounds = first_spec.get("bounds", (0.0, 1e-4))
        initial = first_spec.get("initial", bounds[0])

        dynamics = _make_dynamics(_make_ctcs_new_state(dynamics.f, penalty_fns))
        dynamics_prop = _make_dynamics(_make_ctcs_new_state(dynamics_prop.f, penalty_fns))

        # State bookkeeping (needed for CVXPy variable creation); the new state
        # sits right after the current end of each unified vector.
        states.append(_make_aug_state(idx, bounds, initial, N, x_unified.shape[0]))
        states_prop.append(_make_aug_state(idx, bounds, initial, N, x_prop_unified.shape[0]))

        for unified in (x_unified, x_prop_unified):
            unified.append(
                min=bounds[0],
                max=bounds[1],
                guess=initial,
                initial=initial,
                final=0.0,
                augmented=True,
            )

    return dynamics, dynamics_prop