openscvx 0.3.2.dev170__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of openscvx might be problematic. Click here for more details.

Files changed (79):
  1. openscvx/__init__.py +123 -0
  2. openscvx/_version.py +34 -0
  3. openscvx/algorithms/__init__.py +92 -0
  4. openscvx/algorithms/autotuning.py +24 -0
  5. openscvx/algorithms/base.py +351 -0
  6. openscvx/algorithms/optimization_results.py +215 -0
  7. openscvx/algorithms/penalized_trust_region.py +384 -0
  8. openscvx/config.py +437 -0
  9. openscvx/discretization/__init__.py +47 -0
  10. openscvx/discretization/discretization.py +236 -0
  11. openscvx/expert/__init__.py +23 -0
  12. openscvx/expert/byof.py +326 -0
  13. openscvx/expert/lowering.py +419 -0
  14. openscvx/expert/validation.py +357 -0
  15. openscvx/integrators/__init__.py +48 -0
  16. openscvx/integrators/runge_kutta.py +281 -0
  17. openscvx/lowered/__init__.py +30 -0
  18. openscvx/lowered/cvxpy_constraints.py +23 -0
  19. openscvx/lowered/cvxpy_variables.py +124 -0
  20. openscvx/lowered/dynamics.py +34 -0
  21. openscvx/lowered/jax_constraints.py +133 -0
  22. openscvx/lowered/parameters.py +54 -0
  23. openscvx/lowered/problem.py +70 -0
  24. openscvx/lowered/unified.py +718 -0
  25. openscvx/plotting/__init__.py +63 -0
  26. openscvx/plotting/plotting.py +756 -0
  27. openscvx/plotting/scp_iteration.py +299 -0
  28. openscvx/plotting/viser/__init__.py +126 -0
  29. openscvx/plotting/viser/animated.py +605 -0
  30. openscvx/plotting/viser/plotly_integration.py +333 -0
  31. openscvx/plotting/viser/primitives.py +355 -0
  32. openscvx/plotting/viser/scp.py +459 -0
  33. openscvx/plotting/viser/server.py +112 -0
  34. openscvx/problem.py +734 -0
  35. openscvx/propagation/__init__.py +60 -0
  36. openscvx/propagation/post_processing.py +104 -0
  37. openscvx/propagation/propagation.py +248 -0
  38. openscvx/solvers/__init__.py +51 -0
  39. openscvx/solvers/cvxpy.py +226 -0
  40. openscvx/symbolic/__init__.py +9 -0
  41. openscvx/symbolic/augmentation.py +630 -0
  42. openscvx/symbolic/builder.py +492 -0
  43. openscvx/symbolic/constraint_set.py +92 -0
  44. openscvx/symbolic/expr/__init__.py +222 -0
  45. openscvx/symbolic/expr/arithmetic.py +517 -0
  46. openscvx/symbolic/expr/array.py +632 -0
  47. openscvx/symbolic/expr/constraint.py +796 -0
  48. openscvx/symbolic/expr/control.py +135 -0
  49. openscvx/symbolic/expr/expr.py +720 -0
  50. openscvx/symbolic/expr/lie/__init__.py +87 -0
  51. openscvx/symbolic/expr/lie/adjoint.py +357 -0
  52. openscvx/symbolic/expr/lie/se3.py +172 -0
  53. openscvx/symbolic/expr/lie/so3.py +138 -0
  54. openscvx/symbolic/expr/linalg.py +279 -0
  55. openscvx/symbolic/expr/math.py +699 -0
  56. openscvx/symbolic/expr/spatial.py +209 -0
  57. openscvx/symbolic/expr/state.py +607 -0
  58. openscvx/symbolic/expr/stl.py +136 -0
  59. openscvx/symbolic/expr/variable.py +321 -0
  60. openscvx/symbolic/hashing.py +112 -0
  61. openscvx/symbolic/lower.py +760 -0
  62. openscvx/symbolic/lowerers/__init__.py +106 -0
  63. openscvx/symbolic/lowerers/cvxpy.py +1302 -0
  64. openscvx/symbolic/lowerers/jax.py +1382 -0
  65. openscvx/symbolic/preprocessing.py +757 -0
  66. openscvx/symbolic/problem.py +110 -0
  67. openscvx/symbolic/time.py +116 -0
  68. openscvx/symbolic/unified.py +420 -0
  69. openscvx/utils/__init__.py +20 -0
  70. openscvx/utils/cache.py +131 -0
  71. openscvx/utils/caching.py +210 -0
  72. openscvx/utils/printing.py +301 -0
  73. openscvx/utils/profiling.py +37 -0
  74. openscvx/utils/utils.py +100 -0
  75. openscvx-0.3.2.dev170.dist-info/METADATA +350 -0
  76. openscvx-0.3.2.dev170.dist-info/RECORD +79 -0
  77. openscvx-0.3.2.dev170.dist-info/WHEEL +5 -0
  78. openscvx-0.3.2.dev170.dist-info/licenses/LICENSE +201 -0
  79. openscvx-0.3.2.dev170.dist-info/top_level.txt +1 -0
@@ -0,0 +1,1382 @@
1
+ """JAX backend for lowering symbolic expressions to executable functions.
2
+
3
+ This module implements the JAX lowering backend that converts symbolic expression
4
+ AST nodes into JAX functions with automatic differentiation support. The lowering
5
+ uses a visitor pattern where each expression type has a corresponding visitor method.
6
+
7
+ Architecture:
8
+ The JAX lowerer follows a visitor pattern with centralized registration:
9
+
10
+ 1. **Visitor Registration**: The @visitor decorator registers handler functions
11
+ for each expression type in the _JAX_VISITORS dictionary
12
+ 2. **Dispatch**: The dispatch() function looks up and calls the appropriate
13
+ visitor based on the expression's type
14
+ 3. **Recursive Lowering**: Each visitor recursively lowers child expressions
15
+ and composes JAX operations
16
+ 4. **Standardized Signature**: All lowered functions have signature
17
+ (x, u, node, params) -> result for uniformity
18
+
19
+ Key Features:
20
+ - **Automatic Differentiation**: Lowered functions can be differentiated using
21
+ JAX's jacfwd/jacrev for computing Jacobians
22
+ - **JIT Compilation**: All functions are JAX-traceable and JIT-compatible
23
+ - **Functional Closures**: Each visitor returns a closure that captures
24
+ necessary constants and child functions
25
+ - **Broadcasting**: Supports NumPy-style broadcasting through jnp operations
26
+
27
+ Lowered Function Signature:
28
+ All lowered functions have a uniform signature::
29
+
30
+ f(x, u, node, params) -> result
31
+
32
+ Where:
33
+
34
+ - x: State vector (jnp.ndarray)
35
+ - u: Control vector (jnp.ndarray)
36
+ - node: Node index for time-varying behavior (scalar or array)
37
+ - params: Dictionary of parameter values (dict[str, Any])
38
+ - result: JAX array (scalar, vector, or matrix)
39
+
40
+ Example:
41
+ Basic usage::
42
+
43
+ from openscvx.symbolic.lowerers.jax import JaxLowerer
44
+ import openscvx as ox
45
+
46
+ # Create symbolic expression
47
+ x = ox.State("x", shape=(3,))
48
+ u = ox.Control("u", shape=(2,))
49
+ expr = ox.Norm(x)**2 + 0.1 * ox.Norm(u)**2
50
+
51
+ # Lower to JAX
52
+ lowerer = JaxLowerer()
53
+ f = lowerer.lower(expr)
54
+
55
+ # Evaluate
56
+ import jax.numpy as jnp
57
+ x_val = jnp.array([1.0, 2.0, 3.0])
58
+ u_val = jnp.array([0.5, 0.5])
59
+ result = f(x_val, u_val, node=0, params={})
60
+
61
+ # Differentiate
62
+ from jax import jacfwd
63
+ df_dx = jacfwd(f, argnums=0)
64
+ gradient = df_dx(x_val, u_val, node=0, params={})
65
+
66
+ For Contributors:
67
+ **Adding Support for New Expression Types**
68
+
69
+ To add support for a new symbolic expression type to JAX lowering:
70
+
71
+ 1. **Define the visitor method** in JaxLowerer with the @visitor decorator::
72
+
73
+ @visitor(MyNewExpr)
74
+ def _visit_my_new_expr(self, node: MyNewExpr):
75
+ # Lower child expressions recursively
76
+ operand_fn = self.lower(node.operand)
77
+
78
+ # Return a closure with signature (x, u, node, params) -> result
79
+ return lambda x, u, node, params: jnp.my_operation(
80
+ operand_fn(x, u, node, params)
81
+ )
82
+
83
+ 2. **Key requirements**:
84
+ - Use the @visitor(ExprType) decorator to register the handler
85
+ - Method name should be _visit_<expr_name> (private, lowercase, snake_case)
86
+ - Recursively lower all child expressions using self.lower()
87
+ - Return a closure with signature (x, u, node, params) -> jax_array
88
+ - Use jnp.* operations (not np.*) for JAX traceability
89
+ - Ensure the result is JAX-differentiable (avoid Python control flow)
90
+
91
+ 3. **Example patterns**:
92
+ - Unary operation: Lower operand, apply jnp function
93
+ - Binary operation: Lower both operands, combine with jnp operation
94
+ - N-ary operation: Lower all operands, reduce or combine them
95
+ - Conditional logic: Use jax.lax.cond for branching (see _visit_ctcs)
96
+
97
+ 4. **Testing**: Ensure your visitor works with:
98
+ - JAX JIT compilation: jax.jit(lowered_fn)
99
+ - Automatic differentiation: jax.jacfwd(lowered_fn, argnums=0)
100
+ - Vectorization: jax.vmap(lowered_fn)
101
+
102
+ See Also:
103
+ - lower_to_jax(): Convenience wrapper in symbolic/lower.py
104
+ - CVXPyLowerer: Alternative backend for convex constraints
105
+ - dispatch(): Core dispatch function for visitor pattern
106
+ """
107
+
108
+ from typing import Any, Callable, Dict, Type
109
+
110
+ import jax.numpy as jnp
111
+ from jax.lax import cond
112
+ from jax.scipy.special import logsumexp
113
+
114
+ from openscvx.symbolic.expr import (
115
+ CTCS,
116
+ QDCM,
117
+ SSM,
118
+ SSMP,
119
+ Abs,
120
+ Add,
121
+ Adjoint,
122
+ AdjointDual,
123
+ Block,
124
+ Concat,
125
+ Constant,
126
+ Constraint,
127
+ Cos,
128
+ CrossNodeConstraint,
129
+ Diag,
130
+ Div,
131
+ Equality,
132
+ Exp,
133
+ Expr,
134
+ Hstack,
135
+ Huber,
136
+ Index,
137
+ Inequality,
138
+ Log,
139
+ LogSumExp,
140
+ MatMul,
141
+ Max,
142
+ Mul,
143
+ Neg,
144
+ NodalConstraint,
145
+ NodeReference,
146
+ Norm,
147
+ Or,
148
+ Parameter,
149
+ PositivePart,
150
+ Power,
151
+ SE3Adjoint,
152
+ SE3AdjointDual,
153
+ Sin,
154
+ SmoothReLU,
155
+ Sqrt,
156
+ Square,
157
+ Stack,
158
+ Sub,
159
+ Sum,
160
+ Tan,
161
+ Transpose,
162
+ Vstack,
163
+ )
164
+ from openscvx.symbolic.expr.control import Control
165
+ from openscvx.symbolic.expr.lie import (
166
+ SE3Exp,
167
+ SE3Log,
168
+ SO3Exp,
169
+ SO3Log,
170
+ )
171
+ from openscvx.symbolic.expr.state import State
172
+
173
# Global registry consulted by dispatch(): maps each Expr subclass to the
# visitor function recorded for it by the @visitor decorator.
_JAX_VISITORS: Dict[Type[Expr], Callable] = dict()
"""Registry mapping expression types to their visitor functions."""
175
+
176
+
177
def visitor(expr_cls: Type[Expr]):
    """Build a decorator that registers a JAX visitor for ``expr_cls``.

    The returned decorator records the wrapped function in ``_JAX_VISITORS``
    under the key ``expr_cls`` and hands the function back untouched, so the
    decorated name still refers to the original function. ``dispatch()`` later
    retrieves it by exact-type lookup.

    Args:
        expr_cls: The Expr subclass the decorated function lowers
            (e.g. Add, Mul, Norm).

    Returns:
        A decorator that stores the function in the registry and returns it
        unchanged.

    Example:
        Register a visitor for the Add expression::

            @visitor(Add)
            def _visit_add(self, node: Add):
                # Lower addition to JAX
                ...

    Note:
        Stacking decorators lets one function serve several types::

            @visitor(Equality)
            @visitor(Inequality)
            def _visit_constraint(self, node: Constraint):
                # Handle both equality and inequality
                ...

        Registering a second function for the same type overwrites the first.
    """

    def _record(fn: Callable[[Any, Expr], Callable]):
        _JAX_VISITORS.update({expr_cls: fn})
        return fn

    return _record
213
+
214
+
215
def dispatch(lowerer: Any, expr: Expr):
    """Route ``expr`` to the visitor registered for its exact type.

    The lookup key is ``type(expr)`` itself, so a subclass of a registered
    type is not matched unless it has its own registration.

    Args:
        lowerer: The JaxLowerer instance, passed to the visitor as ``self``.
        expr: The expression node to lower.

    Returns:
        Whatever the visitor returns (for JaxLowerer, a callable with
        signature ``(x, u, node, params) -> result``).

    Raises:
        NotImplementedError: If no visitor is registered for ``type(expr)``.

    Example:
        Lower an Add node::

            lowerer = JaxLowerer()
            fn = dispatch(lowerer, Add(x, y))  # invokes JaxLowerer._visit_add
    """
    expr_type = type(expr)
    handler = _JAX_VISITORS.get(expr_type)
    if handler is not None:
        return handler(lowerer, expr)
    raise NotImplementedError(
        f"{lowerer.__class__.__name__!r} has no visitor for {expr_type.__name__}"
    )
244
+
245
+
246
+ class JaxLowerer:
247
+ """JAX backend for lowering symbolic expressions to executable functions.
248
+
249
+ This class implements the visitor pattern for converting symbolic expression
250
+ AST nodes to JAX functions. Each expression type has a corresponding visitor
251
+ method decorated with @visitor that handles the lowering logic.
252
+
253
+ The lowering process is recursive: each visitor lowers its child expressions
254
+ first, then composes them into a JAX operation. All lowered functions have
255
+ a standardized signature (x, u, node, params) -> result.
256
+
257
+ Attributes:
258
+ None (stateless lowerer - all state is in the expression tree)
259
+
260
+ Example:
261
+ Set up the JaxLowerer and lower an expression to a JAX function:
262
+
263
+ lowerer = JaxLowerer()
264
+ expr = ox.Norm(x)**2 + 0.1 * ox.Norm(u)**2
265
+ f = lowerer.lower(expr)
266
+ result = f(x_val, u_val, node=0, params={})
267
+
268
+ Note:
269
+ The lowerer is stateless and can be reused for multiple expressions.
270
+ All visitor methods are instance methods to maintain a clean interface,
271
+ but they don't modify instance state.
272
+ """
273
+
274
def lower(self, expr: Expr):
    """Translate a symbolic expression into a JAX-callable.

    This is the public entry point. The work is delegated to ``dispatch``,
    which selects the visitor registered for ``type(expr)``.

    Args:
        expr: Any Expr subclass instance to lower.

    Returns:
        A function with signature ``(x, u, node, params) -> result``.

    Raises:
        NotImplementedError: If no visitor is registered for the expression type.
        ValueError: Propagated from visitors for malformed input, e.g. a State
            or Control that has no slice assigned.

    Example:
        Lower a norm expression::

            lowerer = JaxLowerer()
            x = ox.State("x", shape=(3,))
            f = lowerer.lower(ox.Norm(x))
            # f(x_val, u_val, node, params) now evaluates the norm of x
    """
    lowered = dispatch(self, expr)
    return lowered
300
+
301
@visitor(Constant)
def _visit_constant(self, node: Constant):
    """Lower a Constant to a function that ignores its inputs.

    The value is converted to a JAX array once, at lowering time. Arrays that
    hold exactly one element are squeezed to rank 0, so scalar constants do
    not introduce (1,)-shaped constraint residuals.

    Args:
        node: Constant expression node.

    Returns:
        Function (x, u, node, params) -> the captured constant array.
    """
    captured = jnp.array(node.value)
    # Single-element arrays become true scalars (shape ()).
    captured = captured.squeeze() if captured.size == 1 else captured

    def constant_fn(x, u, node, params):
        return captured

    return constant_fn
321
+
322
@visitor(State)
def _visit_state(self, node: State):
    """Lower a State to a slice of the unified state vector ``x``.

    Args:
        node: State expression node. Its ``_slice`` must already be assigned
            (by unification) before lowering.

    Returns:
        Function (x, u, node, params) -> x[node._slice].

    Raises:
        ValueError: If ``node._slice`` is None.
    """
    state_slice = node._slice
    if state_slice is not None:
        return lambda x, u, node, params: x[state_slice]
    raise ValueError(f"State {node.name!r} has no slice assigned")
342
+
343
@visitor(Control)
def _visit_control(self, node: Control):
    """Lower a Control to a slice of the unified control vector ``u``.

    Args:
        node: Control expression node. Its ``_slice`` must already be assigned
            (by unification) before lowering.

    Returns:
        Function (x, u, node, params) -> u[node._slice].

    Raises:
        ValueError: If ``node._slice`` is None.
    """
    control_slice = node._slice
    if control_slice is not None:
        return lambda x, u, node, params: u[control_slice]
    raise ValueError(f"Control {node.name!r} has no slice assigned")
363
+
364
@visitor(Parameter)
def _visit_parameter(self, node: Parameter):
    """Lower a Parameter to a by-name lookup in the runtime ``params`` mapping.

    The value is read on every call, so parameter values can change between
    calls without re-lowering the expression.

    Args:
        node: Parameter expression node.

    Returns:
        Function (x, u, node, params) -> params[node.name] as a JAX array.

    Raises:
        KeyError: At call time, if ``params`` has no entry for the name.
    """
    param_name = node.name
    # jnp.asarray returns JAX-array inputs as-is. jnp.array copies by default,
    # which would cost a needless copy on every evaluation.
    return lambda x, u, node, params: jnp.asarray(params[param_name])
379
+
380
@visitor(Add)
def _visit_add(self, node: Add):
    """Lower an n-ary Add to left-to-right element-wise summation.

    Operands broadcast following NumPy/JAX rules.

    Args:
        node: Add expression node; ``node.terms`` holds the summands.

    Returns:
        Function (x, u, node, params) -> term_0 + term_1 + ... + term_k.
    """
    term_fns = [self.lower(term) for term in node.terms]

    def add_fn(x, u, node, params):
        head = term_fns[0](x, u, node, params)
        # sum(iterable, start) computes ((head + t1) + t2) + ..., the same
        # left-to-right order as an explicit accumulator loop.
        return sum((g(x, u, node, params) for g in term_fns[1:]), head)

    return add_fn
402
+
403
@visitor(Sub)
def _visit_sub(self, node: Sub):
    """Lower Sub to element-wise ``left - right`` (with broadcasting)."""
    lhs_fn, rhs_fn = self.lower(node.left), self.lower(node.right)

    def sub_fn(x, u, node, params):
        return lhs_fn(x, u, node, params) - rhs_fn(x, u, node, params)

    return sub_fn
409
+
410
@visitor(Mul)
def _visit_mul(self, node: Mul):
    """Lower n-ary Mul to a left-to-right element-wise (Hadamard) product.

    Args:
        node: Mul expression node; ``node.factors`` holds the operands.

    Returns:
        Function (x, u, node, params) -> factor_0 * factor_1 * ... * factor_k.
    """
    factor_fns = [self.lower(factor) for factor in node.factors]

    def mul_fn(x, u, node, params):
        product = factor_fns[0](x, u, node, params)
        for factor_fn in factor_fns[1:]:
            product = product * factor_fn(x, u, node, params)
        return product

    return mul_fn
422
+
423
@visitor(Div)
def _visit_div(self, node: Div):
    """Lower Div to element-wise ``left / right`` (with broadcasting)."""
    numerator_fn, denominator_fn = self.lower(node.left), self.lower(node.right)

    def div_fn(x, u, node, params):
        return numerator_fn(x, u, node, params) / denominator_fn(x, u, node, params)

    return div_fn
429
+
430
@visitor(MatMul)
def _visit_matmul(self, node: MatMul):
    """Lower MatMul to ``jnp.matmul(left, right)``."""
    lhs_fn = self.lower(node.left)
    rhs_fn = self.lower(node.right)

    def matmul_fn(x, u, node, params):
        lhs = lhs_fn(x, u, node, params)
        rhs = rhs_fn(x, u, node, params)
        return jnp.matmul(lhs, rhs)

    return matmul_fn
436
+
437
@visitor(Neg)
def _visit_neg(self, node: Neg):
    """Lower Neg (unary minus) to element-wise negation of the operand."""
    operand_fn = self.lower(node.operand)

    def neg_fn(x, u, node, params):
        return -operand_fn(x, u, node, params)

    return neg_fn
442
+
443
@visitor(Sum)
def _visit_sum(self, node: Sum):
    """Lower Sum to a full reduction (``jnp.sum`` over every element)."""
    operand_fn = self.lower(node.operand)

    def sum_fn(x, u, node, params):
        # No axis argument: all axes are reduced, yielding a scalar.
        return jnp.sum(operand_fn(x, u, node, params))

    return sum_fn
448
+
449
@visitor(Norm)
def _visit_norm(self, node: Norm):
    """Lower Norm to ``jnp.linalg.norm`` with a translated ``ord``.

    String spellings of ``ord`` are mapped to values ``jnp.linalg.norm``
    accepts: "inf" -> inf, "-inf" -> -inf, "fro" -> None. None is the
    library default: the 2-norm for vectors, the Frobenius norm for matrices.
    Any other ``ord`` is passed through unchanged.

    Args:
        node: Norm expression node with an ``ord`` attribute.

    Returns:
        Function (x, u, node, params) -> norm of the operand.
    """
    operand_fn = self.lower(node.operand)
    norm_ord = node.ord
    for label, replacement in (("inf", jnp.inf), ("-inf", -jnp.inf), ("fro", None)):
        if norm_ord == label:
            norm_ord = replacement
            break

    def norm_fn(x, u, node, params):
        return jnp.linalg.norm(operand_fn(x, u, node, params), ord=norm_ord)

    return norm_fn
475
+
476
@visitor(Index)
def _visit_index(self, node: Index):
    """Lower indexing/slicing to a JAX function.

    States are stored flattened in the unified vectors. When the base
    expression is multi-dimensional, its value is therefore reshaped back to
    ``node.base.check_shape()`` before ``node.index`` is applied. Bases of
    rank 0 or 1 are promoted with ``jnp.atleast_1d`` instead.

    Args:
        node: Index expression node with ``base`` and ``index``.

    Returns:
        Function (x, u, node, params) -> base[index].
    """
    f_base = self.lower(node.base)
    idx = node.index
    base_shape = node.base.check_shape()
    is_multi_dim = len(base_shape) > 1

    # The inner parameter is named ``node`` so the closure matches the
    # documented (x, u, node, params) signature and accepts ``node=`` as a
    # keyword, like every other lowered function. Shadowing the Index
    # expression is harmless because the closure never references it.
    def index_fn(x, u, node, params):
        arr = f_base(x, u, node, params)
        if is_multi_dim:
            arr = arr.reshape(base_shape)
        else:
            arr = jnp.atleast_1d(arr)
        return arr[idx]

    return index_fn
498
+
499
@visitor(Concat)
def _visit_concat(self, node: Concat):
    """Lower Concat to ``jnp.concatenate`` along axis 0.

    Each child value is promoted with ``jnp.atleast_1d`` first, so scalar
    children contribute one element each.
    """
    child_fns = [self.lower(child) for child in node.exprs]

    def concat_fn(x, u, node, params):
        values = [child_fn(x, u, node, params) for child_fn in child_fns]
        return jnp.concatenate([jnp.atleast_1d(v) for v in values], axis=0)

    return concat_fn
511
+
512
@visitor(Sin)
def _visit_sin(self, node: Sin):
    """Lower Sin to element-wise ``jnp.sin`` of the operand."""
    operand_fn = self.lower(node.operand)

    def sin_fn(x, u, node, params):
        return jnp.sin(operand_fn(x, u, node, params))

    return sin_fn
517
+
518
@visitor(Cos)
def _visit_cos(self, node: Cos):
    """Lower Cos to element-wise ``jnp.cos`` of the operand."""
    operand_fn = self.lower(node.operand)

    def cos_fn(x, u, node, params):
        return jnp.cos(operand_fn(x, u, node, params))

    return cos_fn
523
+
524
@visitor(Tan)
def _visit_tan(self, node: Tan):
    """Lower Tan to element-wise ``jnp.tan`` of the operand."""
    operand_fn = self.lower(node.operand)

    def tan_fn(x, u, node, params):
        return jnp.tan(operand_fn(x, u, node, params))

    return tan_fn
529
+
530
@visitor(Exp)
def _visit_exp(self, node: Exp):
    """Lower Exp to element-wise ``jnp.exp`` of the operand."""
    operand_fn = self.lower(node.operand)

    def exp_fn(x, u, node, params):
        return jnp.exp(operand_fn(x, u, node, params))

    return exp_fn
535
+
536
@visitor(Log)
def _visit_log(self, node: Log):
    """Lower Log to the element-wise natural logarithm (``jnp.log``)."""
    operand_fn = self.lower(node.operand)

    def log_fn(x, u, node, params):
        return jnp.log(operand_fn(x, u, node, params))

    return log_fn
541
+
542
@visitor(Abs)
def _visit_abs(self, node: Abs):
    """Lower Abs to element-wise ``jnp.abs`` of the operand."""
    operand_fn = self.lower(node.operand)

    def abs_fn(x, u, node, params):
        return jnp.abs(operand_fn(x, u, node, params))

    return abs_fn
547
+
548
@visitor(Equality)
@visitor(Inequality)
def _visit_constraint(self, node: Constraint):
    """Lower an Equality or Inequality constraint to its residual ``lhs - rhs``.

    One residual form serves both kinds. An equality (lhs == rhs) holds when
    the residual is zero; an inequality (lhs <= rhs) holds when it is
    non-positive. Penalty methods and Lagrangian terms consume this residual.

    Args:
        node: Equality or Inequality constraint node.

    Returns:
        Function (x, u, node, params) -> lhs - rhs.
    """
    lhs_fn = self.lower(node.lhs)
    rhs_fn = self.lower(node.rhs)

    def residual_fn(x, u, node, params):
        return lhs_fn(x, u, node, params) - rhs_fn(x, u, node, params)

    return residual_fn
571
+
572
# TODO: (norrisg) CTCS is playing 2 roles here: both as a constraint wrapper and as the penalty
# expression w/ conditional logic. Consider adding conditional logic as separate AST nodes.
# Then, CTCS remains a wrapper and we just wrap the penalty expression with the conditional
# logic when we lower it.
@visitor(CTCS)
def _visit_ctcs(self, node: CTCS):
    """Lower CTCS (Continuous-Time Constraint Satisfaction) to a JAX function.

    CTCS constraints use penalty methods to enforce constraints over
    continuous time intervals. The lowered function applies the penalty only
    within the configured node interval.

    Args:
        node: CTCS constraint node with a penalty expression and an optional
            ``nodes`` range ``(start_node, end_node)``.

    Returns:
        Function (x, u, node, params) -> the penalty value when active. When
        inactive, it returns zeros with the penalty's shape and dtype.

    Note:
        Uses jax.lax.cond for JAX-traceable conditional evaluation. The
        penalty is active only when the node index lies in
        [start_node, end_node). If no node range is specified, the penalty is
        always active.

    See Also:
        - CTCS: The symbolic CTCS constraint class
        - penalty functions: PositivePart, Huber, SmoothReLU
    """
    # Lower the penalty expression (which includes the constraint residual).
    penalty_expr_fn = self.lower(node.penalty_expr())
    # Read the node range once, at lowering time. The closures below use
    # ``node`` for the runtime node index, per the lowered-function signature.
    node_range = node.nodes

    if node_range is None:

        def always_active_fn(x, u, node, params):
            return penalty_expr_fn(x, u, node, params)

        return always_active_fn

    start_node, end_node = node_range

    def ctcs_fn(x, u, node, params):
        # The node index may be a scalar or an array. Take its first element,
        # keeping it a JAX value so it remains traceable.
        node_scalar = jnp.atleast_1d(node)[0]
        is_active = (start_node <= node_scalar) & (node_scalar < end_node)
        return cond(
            is_active,
            lambda _: penalty_expr_fn(x, u, node, params),
            # lax.cond requires both branches to return identical types. Zeros
            # shaped like the penalty satisfy this for any penalty shape and
            # dtype; a bare 0.0 only does for compatible scalar penalties.
            lambda _: jnp.zeros_like(penalty_expr_fn(x, u, node, params)),
            operand=None,
        )

    return ctcs_fn
623
+
624
@visitor(PositivePart)
def _visit_pos(self, node):
    """Lower PositivePart to element-wise ``max(operand, 0)``.

    Used by penalty methods for inequality constraints.

    Args:
        node: PositivePart expression node (operand in ``node.x``).

    Returns:
        Function (x, u, node, params) -> jnp.maximum(operand, 0.0).
    """
    inner_fn = self.lower(node.x)

    def pos_fn(x, u, node, params):
        return jnp.maximum(inner_fn(x, u, node, params), 0.0)

    return pos_fn
638
+
639
@visitor(Square)
def _visit_square(self, node):
    """Lower Square to element-wise ``operand * operand``.

    Used in quadratic penalty methods. The operand is evaluated once per call
    and its value multiplied by itself. Evaluating it twice would double the
    work, and for nested Square nodes the trace cost would grow exponentially
    with depth.

    Args:
        node: Square expression node (operand in ``node.x``).

    Returns:
        Function (x, u, node, params) -> operand^2.
    """
    f = self.lower(node.x)

    def square_fn(x, u, node, params):
        val = f(x, u, node, params)
        return val * val

    return square_fn
653
+
654
@visitor(Huber)
def _visit_huber(self, node):
    """Lower the Huber penalty to JAX.

    The penalty is quadratic for small values and linear for large ones:
        - |r| <= delta: 0.5 * r^2
        - |r| >  delta: delta * (|r| - 0.5 * delta)

    The operand is evaluated once per call and reused for both branches.
    Previously it was evaluated three times.

    Args:
        node: Huber expression node (operand in ``node.x``, threshold in
            ``node.delta``).

    Returns:
        Function (x, u, node, params) -> element-wise Huber penalty.
    """
    f = self.lower(node.x)
    delta = node.delta

    def huber_fn(x, u, node, params):
        r = f(x, u, node, params)
        abs_r = jnp.abs(r)
        return jnp.where(abs_r <= delta, 0.5 * r**2, delta * (abs_r - 0.5 * delta))

    return huber_fn
675
+
676
@visitor(SmoothReLU)
def _visit_srelu(self, node):
    """Lower SmoothReLU to ``sqrt(max(v, 0)**2 + c**2) - c``.

    ``v`` is the lowered operand. The positive part is applied here via
    ``jnp.maximum``; it is not assumed to be present in the operand. For
    c > 0 the result is 0 when v <= 0, differentiable everywhere, and
    approaches ReLU(v) as c -> 0.

    Args:
        node: SmoothReLU expression node (operand in ``node.x``, smoothing
            parameter in ``node.c``).

    Returns:
        Function (x, u, node, params) -> smooth ReLU penalty.
    """
    inner_fn = self.lower(node.x)
    c = node.c

    def srelu_fn(x, u, node, params):
        clipped = jnp.maximum(inner_fn(x, u, node, params), 0.0)
        return jnp.sqrt(clipped**2 + c**2) - c

    return srelu_fn
696
+
697
@visitor(NodalConstraint)
def _visit_nodal_constraint(self, node: NodalConstraint):
    """Lower a NodalConstraint by lowering the constraint it wraps.

    The wrapper specifies which nodes the constraint applies to. That
    selection is not encoded in the lowered function, which is simply the
    lowering of ``node.constraint``.

    Args:
        node: NodalConstraint wrapper.

    Returns:
        The function produced by lowering the wrapped constraint.
    """
    inner_constraint = node.constraint
    return self.lower(inner_constraint)
711
+
712
@visitor(Sqrt)
def _visit_sqrt(self, node: Sqrt):
    """Lower Sqrt to element-wise ``jnp.sqrt`` of the operand."""
    operand_fn = self.lower(node.operand)

    def sqrt_fn(x, u, node, params):
        return jnp.sqrt(operand_fn(x, u, node, params))

    return sqrt_fn
717
+
718
@visitor(Max)
def _visit_max(self, node: Max):
    """Lower n-ary Max to an element-wise maximum over all operands.

    ``jnp.maximum`` is binary, so the operands are folded left to right:
    maximum(maximum(v0, v1), v2), ... Operands broadcast under JAX rules.
    """
    operand_fns = [self.lower(op) for op in node.operands]

    def max_fn(x, u, node, params):
        values = [operand_fn(x, u, node, params) for operand_fn in operand_fns]
        running = values[0]
        for candidate in values[1:]:
            running = jnp.maximum(running, candidate)
        return running

    return max_fn
732
+
733
@visitor(LogSumExp)
def _visit_logsumexp(self, node: LogSumExp):
    """Lower LogSumExp to an element-wise ``log(sum_i exp(v_i))`` across operands.

    This is a smooth approximation of the element-wise maximum. Operands are
    broadcast to a common shape and stacked along a new leading axis, which
    is then reduced with JAX's numerically stable ``logsumexp``.
    """
    operand_fns = [self.lower(op) for op in node.operands]

    def lse_fn(x, u, node, params):
        values = [operand_fn(x, u, node, params) for operand_fn in operand_fns]
        stacked = jnp.stack(jnp.broadcast_arrays(*values), axis=0)
        return logsumexp(stacked, axis=0)

    return lse_fn
753
+
754
@visitor(Transpose)
def _visit_transpose(self, node: Transpose):
    """Lower Transpose to ``jnp.transpose`` (all axes reversed)."""
    operand_fn = self.lower(node.operand)

    def transpose_fn(x, u, node, params):
        return jnp.transpose(operand_fn(x, u, node, params))

    return transpose_fn
759
+
760
@visitor(Power)
def _visit_power(self, node: Power):
    """Lower Power to element-wise ``jnp.power(base, exponent)``."""
    base_fn = self.lower(node.base)
    exponent_fn = self.lower(node.exponent)

    def power_fn(x, u, node, params):
        base = base_fn(x, u, node, params)
        exponent = exponent_fn(x, u, node, params)
        return jnp.power(base, exponent)

    return power_fn
766
+
767
@visitor(Stack)
def _visit_stack(self, node: Stack):
    """Lower Stack by stacking the rows along a new leading axis (axis 0).

    Each row is promoted with ``jnp.atleast_1d`` first, so scalar rows become
    length-1 vectors.
    """
    row_fns = [self.lower(row) for row in node.rows]

    def stack_fn(x, u, node, params):
        rows = [row_fn(x, u, node, params) for row_fn in row_fns]
        return jnp.stack([jnp.atleast_1d(r) for r in rows], axis=0)

    return stack_fn
777
+
778
@visitor(Hstack)
def _visit_hstack(self, node: Hstack):
    """Lower Hstack to ``jnp.hstack`` over the (at least 1-D) operand values."""
    piece_fns = [self.lower(arr) for arr in node.arrays]

    def hstack_fn(x, u, node, params):
        pieces = [piece_fn(x, u, node, params) for piece_fn in piece_fns]
        return jnp.hstack([jnp.atleast_1d(p) for p in pieces])

    return hstack_fn
788
+
789
@visitor(Vstack)
def _visit_vstack(self, node: Vstack):
    """Lower Vstack to ``jnp.vstack`` over the (at least 1-D) operand values."""
    piece_fns = [self.lower(arr) for arr in node.arrays]

    def vstack_fn(x, u, node, params):
        pieces = [piece_fn(x, u, node, params) for piece_fn in piece_fns]
        return jnp.vstack([jnp.atleast_1d(p) for p in pieces])

    return vstack_fn
799
+
800
@visitor(Block)
def _visit_block(self, node: Block):
    """Lower block-matrix construction to a JAX function.

    Assembles a block matrix from a nested (row-major) list of expressions.
    When every block has at most 2 dimensions, ``jnp.block`` is used
    directly. Otherwise the blocks are promoted to a common rank by
    prepending singleton axes, concatenated along axis 1 within each row,
    and the rows are concatenated along axis 0. ``jnp.block`` would instead
    join along the trailing axes, which is not block-matrix semantics.

    Args:
        node: Block expression node with a 2D nested structure of expressions.

    Returns:
        Function (x, u, node, params) -> assembled block matrix/tensor.
    """
    block_fns = [[self.lower(block) for block in row] for row in node.blocks]

    def _promote_to_ndim(arr, target_ndim):
        # Prepend singleton axes until arr has target_ndim dimensions.
        while arr.ndim < target_ndim:
            arr = jnp.expand_dims(arr, axis=0)
        return arr

    # The inner parameter is named ``node`` so the closure matches the
    # documented (x, u, node, params) signature and accepts ``node=`` as a
    # keyword. The Block expression is not referenced inside the closure.
    def block_fn(x, u, node, params):
        block_values = [
            [jnp.atleast_1d(fn(x, u, node, params)) for fn in row] for row in block_fns
        ]

        max_ndim = max(arr.ndim for row in block_values for arr in row)
        if max_ndim <= 2:
            return jnp.block(block_values)

        block_values = [
            [_promote_to_ndim(arr, max_ndim) for arr in row] for row in block_values
        ]
        # Join each row horizontally (axis 1), then stack rows vertically (axis 0).
        row_results = [jnp.concatenate(row, axis=1) for row in block_values]
        return jnp.concatenate(row_results, axis=0)

    return block_fn
848
+
849
+ @visitor(QDCM)
850
def _visit_qdcm(self, node: QDCM):
    """Lower a quaternion -> direction cosine matrix (DCM) conversion.

    The lowered function evaluates the quaternion, divides it by its
    Euclidean norm (computed from its first four components), and builds
    the 3x3 rotation matrix from the standard quaternion-to-DCM formula.

    Args:
        node: QDCM expression node.

    Returns:
        Function (x, u, node, params) -> 3x3 rotation matrix.

    Note:
        Quaternion convention: [w, x, y, z] where w is the scalar part.
    """
    quat_fn = self.lower(node.q)

    def qdcm_fn(x, u, node, params):
        quat = quat_fn(x, u, node, params)
        magnitude = jnp.sqrt(quat[0] ** 2 + quat[1] ** 2 + quat[2] ** 2 + quat[3] ** 2)
        w, qx, qy, qz = quat / magnitude

        first = [1 - 2 * (qy**2 + qz**2), 2 * (qx * qy - qz * w), 2 * (qx * qz + qy * w)]
        second = [2 * (qx * qy + qz * w), 1 - 2 * (qx**2 + qz**2), 2 * (qy * qz - qx * w)]
        third = [2 * (qx * qz - qy * w), 2 * (qy * qz + qx * w), 1 - 2 * (qx**2 + qy**2)]
        return jnp.array([first, second, third])

    return qdcm_fn
885
+
886
+ @visitor(SSMP)
887
def _visit_ssmp(self, node: SSMP):
    """Lower the 4x4 skew-symmetric matrix used in quaternion kinematics.

    Builds Omega(w) from the first three components of the angular velocity,
    so that q_dot = 0.5 * SSMP(omega) @ q.

    Args:
        node: SSMP expression node.

    Returns:
        Function (x, u, node, params) -> 4x4 skew-symmetric matrix.

    Note:
        For angular velocity w = [x, y, z], returns:
            [[0, -x, -y, -z],
             [x,  0,  z, -y],
             [y, -z,  0,  x],
             [z,  y, -x,  0]]
    """
    omega_fn = self.lower(node.w)

    def ssmp_fn(x, u, node, params):
        omega = omega_fn(x, u, node, params)
        a = omega[0]
        b = omega[1]
        c = omega[2]
        rows = [
            [0, -a, -b, -c],
            [a, 0, c, -b],
            [b, -c, 0, a],
            [c, b, -a, 0],
        ]
        return jnp.array(rows)

    return ssmp_fn
924
+
925
+ @visitor(SSM)
926
def _visit_ssm(self, node: SSM):
    """Lower the 3x3 cross-product (skew-symmetric) matrix of a vector.

    The result S satisfies S(a) @ b == a x b, which turns a cross product
    into a matrix-vector product. Only the first three components of the
    evaluated vector are read.

    Args:
        node: SSM expression node.

    Returns:
        Function (x, u, node, params) -> 3x3 skew-symmetric matrix.

    Note:
        For vector w = [x, y, z], returns:
            [[ 0, -z,  y],
             [ z,  0, -x],
             [-y,  x,  0]]
    """
    vec_fn = self.lower(node.w)

    def ssm_fn(x, u, node, params):
        vec = vec_fn(x, u, node, params)
        a = vec[0]
        b = vec[1]
        c = vec[2]
        rows = [
            [0, -c, b],
            [c, 0, -a],
            [-b, a, 0],
        ]
        return jnp.array(rows)

    return ssm_fn
955
+
956
+ @visitor(AdjointDual)
957
def _visit_adjoint_dual(self, node: AdjointDual):
    """Lower the coadjoint operator ad* for rigid body dynamics.

    Computes the coadjoint action ad*_ξ(μ), which carries the Coriolis and
    centrifugal (gyroscopic) terms of the Newton-Euler equations.

    For se(3), given twist ξ = [v; ω] and momentum μ = [f; τ]:

        ad*_ξ(μ) = [ ω × f         ]
                   [ ω × τ + v × f ]

    This is -ad_ξᵀ μ, where ad_ξ = [[ω]×, [v]×; 0, [ω]×] is the Lie-bracket
    matrix in [v; ω] ordering. It appears in the equations of motion as:

        M @ ξ_dot = F_ext - ad*_ξ(M @ ξ)

    For a body with μ = [m v; J ω] this gives the familiar
    m v_dot = F - ω × m v and J ω_dot = τ_ext - ω × J ω.

    Args:
        node: AdjointDual expression node

    Returns:
        Function (x, u, node, params) -> 6D coadjoint result

    Note:
        Convention: twist = [v; ω] (linear velocity, angular velocity)
                    momentum = [f; τ] (force, torque)
    """
    f_twist = self.lower(node.twist)
    f_momentum = self.lower(node.momentum)

    def adjoint_dual_fn(x, u, node, params):
        twist = f_twist(x, u, node, params)
        momentum = f_momentum(x, u, node, params)

        # Extract components: twist = [v; ω], momentum = [f; τ]
        v = twist[:3]  # Linear velocity
        omega = twist[3:]  # Angular velocity
        f = momentum[:3]  # Force (or linear momentum)
        tau = momentum[3:]  # Torque (or angular momentum)

        # Coadjoint action: ad*_ξ(μ) = [ω × f; ω × τ + v × f].
        # The coupling term v × f has torque units and belongs to the angular
        # row. A "v × τ" term in the linear row would be ad_ξ(μ) (the non-dual
        # action) and is not dimensionally a force.
        linear_part = jnp.cross(omega, f)
        angular_part = jnp.cross(omega, tau) + jnp.cross(v, f)

        return jnp.concatenate([linear_part, angular_part])

    return adjoint_dual_fn
1002
+
1003
+ @visitor(Adjoint)
1004
def _visit_adjoint(self, node: Adjoint):
    """Lower the small adjoint ad (Lie bracket) of two twists.

    With twists ξ₁ = [v₁; ω₁] and ξ₂ = [v₂; ω₂] (linear part first), the
    returned function computes:

        ad_ξ₁(ξ₂) = [ ω₁ × v₂ - ω₂ × v₁ ]
                    [ ω₁ × ω₂           ]

    Args:
        node: Adjoint expression node.

    Returns:
        Function (x, u, node, params) -> 6D Lie bracket result.

    Note:
        The bracket is antisymmetric: [ξ₁, ξ₂] = -[ξ₂, ξ₁].
    """
    first_fn = self.lower(node.twist1)
    second_fn = self.lower(node.twist2)

    def adjoint_fn(x, u, node, params):
        xi_a = first_fn(x, u, node, params)
        xi_b = second_fn(x, u, node, params)

        lin_a, ang_a = xi_a[:3], xi_a[3:]
        lin_b, ang_b = xi_b[:3], xi_b[3:]

        bracket_linear = jnp.cross(ang_a, lin_b) - jnp.cross(ang_b, lin_a)
        bracket_angular = jnp.cross(ang_a, ang_b)
        return jnp.concatenate([bracket_linear, bracket_angular])

    return adjoint_fn
1045
+
1046
+ @visitor(SE3Adjoint)
1047
def _visit_se3_adjoint(self, node: SE3Adjoint):
    """Lower SE3 Adjoint (big Ad) for transforming twists between frames.

    Computes the 6×6 adjoint matrix Ad_T that re-expresses a twist given in
    frame b in frame a:
        ξ_a = Ad_{T_ab} @ ξ_b
    where T_ab = [R p; 0 1] maps frame-b coordinates into frame a.

    With twists ordered [v; ω] (linear first, angular second):
        Ad_T = [ R  [p]×R ]
               [ 0    R   ]
    i.e. v_a = R v_b + p × (R ω_b) and ω_a = R ω_b. This matches jaxlie's
    ``SE3.adjoint()`` for its (v, ω) tangent ordering, which SE3Exp/SE3Log
    also use. The block-lower-triangular form [[R, 0], [[p]×R, R]] only
    applies to [ω; v]-ordered twists.

    Args:
        node: SE3Adjoint expression node

    Returns:
        Function (x, u, node, params) -> 6×6 adjoint matrix
    """
    f_transform = self.lower(node.transform)

    def se3_adjoint_fn(x, u, node, params):
        T = f_transform(x, u, node, params)

        # Extract rotation and translation from 4×4 homogeneous matrix
        R = T[:3, :3]
        p = T[:3, 3]

        # Build skew-symmetric matrix [p]×
        p_skew = jnp.array([[0, -p[2], p[1]], [p[2], 0, -p[0]], [-p[1], p[0], 0]])

        # Ad_T = [ R  [p]×R ]
        #        [ 0    R   ]   for [v; ω] twists
        top_row = jnp.hstack([R, p_skew @ R])
        bottom_row = jnp.hstack([jnp.zeros((3, 3)), R])

        return jnp.vstack([top_row, bottom_row])

    return se3_adjoint_fn
1084
+
1085
+ @visitor(SE3AdjointDual)
1086
def _visit_se3_adjoint_dual(self, node: SE3AdjointDual):
    """Lower SE3 coadjoint (big Ad*) for transforming wrenches between frames.

    Computes the 6×6 coadjoint matrix Ad*_T that transforms wrenches:
        F_a = Ad*_{T_ab} @ F_b
    where T_ab = [R p; 0 1] maps frame-b coordinates into frame a.

    With wrenches ordered [f; τ] (force first, torque second):
        Ad*_T = [   R    0 ]
                [ [p]×R  R ]
    i.e. f_a = R f_b and τ_a = R τ_b + p × (R f_b): the moment about a's
    origin picks up the lever-arm term of the rotated force.

    This is the transpose-inverse of Ad_T = [[R, [p]×R], [0, R]] (the [v; ω]
    twist adjoint). The form [[R, [p]×R], [0, R]] is the twist adjoint
    itself, not its dual.

    Args:
        node: SE3AdjointDual expression node

    Returns:
        Function (x, u, node, params) -> 6×6 coadjoint matrix
    """
    f_transform = self.lower(node.transform)

    def se3_adjoint_dual_fn(x, u, node, params):
        T = f_transform(x, u, node, params)

        # Extract rotation and translation from 4×4 homogeneous matrix
        R = T[:3, :3]
        p = T[:3, 3]

        # Build skew-symmetric matrix [p]×
        p_skew = jnp.array([[0, -p[2], p[1]], [p[2], 0, -p[0]], [-p[1], p[0], 0]])

        # Ad*_T = [   R    0 ]
        #         [ [p]×R  R ]   for [f; τ] wrenches
        top_row = jnp.hstack([R, jnp.zeros((3, 3))])
        bottom_row = jnp.hstack([p_skew @ R, R])

        return jnp.vstack([top_row, bottom_row])

    return se3_adjoint_dual_fn
1125
+
1126
+ @visitor(SO3Exp)
1127
def _visit_so3_exp(self, node: SO3Exp):
    """Lower the SO(3) exponential map via jaxlie.

    The evaluated 3D rotation vector (axis-angle) is passed to
    ``jaxlie.SO3.exp`` and returned as a 3×3 rotation matrix.

    Args:
        node: SO3Exp expression node.

    Returns:
        Function (x, u, node, params) -> 3×3 rotation matrix.
    """
    import jaxlie

    rotvec_fn = self.lower(node.omega)

    def so3_exp_fn(x, u, node, params):
        rotation = jaxlie.SO3.exp(rotvec_fn(x, u, node, params))
        return rotation.as_matrix()

    return so3_exp_fn
1148
+
1149
+ @visitor(SO3Log)
1150
def _visit_so3_log(self, node: SO3Log):
    """Lower the SO(3) logarithm map via jaxlie.

    The evaluated 3×3 rotation matrix is wrapped with
    ``jaxlie.SO3.from_matrix`` and mapped to a 3D rotation vector
    (axis-angle) with ``.log()``.

    Args:
        node: SO3Log expression node.

    Returns:
        Function (x, u, node, params) -> 3D rotation vector.
    """
    import jaxlie

    matrix_fn = self.lower(node.rotation)

    def so3_log_fn(x, u, node, params):
        group_elem = jaxlie.SO3.from_matrix(matrix_fn(x, u, node, params))
        return group_elem.log()

    return so3_log_fn
1171
+
1172
+ @visitor(SE3Exp)
1173
def _visit_se3_exp(self, node: SE3Exp):
    """Lower the SE(3) exponential map via jaxlie.

    The evaluated 6D twist [v; ω] is passed to ``jaxlie.SE3.exp`` and
    returned as a 4×4 homogeneous transformation matrix. The [v; ω]
    ordering (linear first, angular second) is jaxlie's SE3 tangent
    parameterization, so no reordering is done.

    Args:
        node: SE3Exp expression node.

    Returns:
        Function (x, u, node, params) -> 4×4 transformation matrix.
    """
    import jaxlie

    twist_fn = self.lower(node.twist)

    def se3_exp_fn(x, u, node, params):
        pose = jaxlie.SE3.exp(twist_fn(x, u, node, params))
        return pose.as_matrix()

    return se3_exp_fn
1197
+
1198
+ @visitor(SE3Log)
1199
def _visit_se3_log(self, node: SE3Log):
    """Lower the SE(3) logarithm map via jaxlie.

    The evaluated 4×4 homogeneous matrix is wrapped with
    ``jaxlie.SE3.from_matrix`` and mapped to a 6D twist [v; ω] with
    ``.log()``.

    Args:
        node: SE3Log expression node.

    Returns:
        Function (x, u, node, params) -> 6D twist vector.
    """
    import jaxlie

    matrix_fn = self.lower(node.transform)

    def se3_log_fn(x, u, node, params):
        pose = jaxlie.SE3.from_matrix(matrix_fn(x, u, node, params))
        return pose.log()

    return se3_log_fn
1220
+
1221
+ @visitor(Diag)
1222
def _visit_diag(self, node: Diag):
    """Lower a Diag node to a JAX callable that applies ``jnp.diag``.

    The operand is lowered once and its evaluated value is passed straight
    to ``jnp.diag``.
    """
    operand_fn = self.lower(node.operand)

    def diag_fn(x, u, node, params):
        return jnp.diag(operand_fn(x, u, node, params))

    return diag_fn
1226
+
1227
+ @visitor(Or)
1228
def _visit_or(self, node: Or):
    """Lower STL disjunction (Or) to JAX using STLJax library.

    Converts a symbolic Or constraint to an STLJax Or formula for handling
    disjunctive task specifications. Each operand becomes an STLJax predicate.

    Args:
        node: Or expression node with multiple operands

    Returns:
        Function (x, u, node, params) -> STL robustness value

    Note:
        Uses STLJax library for signal temporal logic evaluation. The returned
        function computes the robustness metric for the disjunction, which is
        positive when at least one operand is satisfied.

        STLJax hands each predicate only the signal (here ``x``). The current
        ``u``, ``node`` and ``params`` are forwarded to every operand, so
        operands may depend on controls or the node index as well as on the
        state.

    Example:
        Used for task specifications like "reach goal A OR goal B"::

            goal_A = ox.Norm(x - target_A) <= 1.0
            goal_B = ox.Norm(x - target_B) <= 1.0
            task = ox.Or(goal_A, goal_B)

    See Also:
        - stljax.formula.Or: Underlying STLJax implementation
        - STL robustness: Quantitative measure of constraint satisfaction
    """
    from stljax.formula import Or as STLOr
    from stljax.formula import Predicate

    # Lower each operand to get their functions
    operand_fns = [self.lower(operand) for operand in node.operands]

    # Return a function that evaluates the STLJax Or
    def or_fn(x, u, node, params):
        def make_pred_fn(fn):
            # Factory binds `fn` per operand (avoids late-binding closures).
            # Forward the real u/node instead of None, so operands that read
            # controls do not fail on `None[...]`.
            return lambda signal: fn(signal, u, node, params)

        # Predicates are rebuilt per call because they capture the current
        # u/node/params.
        predicates = [
            Predicate(f"pred_{i}", make_pred_fn(operand_fn))
            for i, operand_fn in enumerate(operand_fns)
        ]

        # Create and evaluate STLJax Or formula
        stl_or = STLOr(*predicates)
        return stl_or(x)

    return or_fn
1279
+
1280
+ @visitor(NodeReference)
1281
def _visit_node_reference(self, node: NodeReference):
    """Lower a NodeReference to a function that reads one trajectory node.

    The referenced node index is fixed when the expression is built, so it is
    captured directly in the returned closure:

    - A bare State is read as ``x[idx, state_slice]``.
    - A bare Control is read as ``u[idx, control_slice]``.
    - Any other base (e.g. ``position[0].at(5)``) is lowered normally. It is
      then evaluated on the single-node rows ``x[idx]`` / ``u[idx]``, with
      inputs of rank <= 1 used unchanged, and with ``idx`` passed as the node
      argument.

    Args:
        node: NodeReference expression with ``base`` and integer ``node_idx``.

    Returns:
        Function (x, u, node_param, params) that extracts from the trajectory.
        ``x``/``u`` are the full trajectories (N, n_x)/(N, n_u);
        ``node_param`` is ignored and kept only for signature compatibility.

    Raises:
        ValueError: If the referenced State or Control has no slice assigned.

    Example:
        ``position.at(5)`` reads ``x[5, position_slice]``;
        ``position.at(k-1)`` with k=7 reads ``x[6, position_slice]``.
    """
    from openscvx.symbolic.expr.control import Control
    from openscvx.symbolic.expr.state import State

    idx = node.node_idx
    base = node.base

    if isinstance(base, State):
        state_slice = base._slice
        if state_slice is None:
            raise ValueError(f"State {base.name!r} has no slice assigned")

        def state_node_fn(x, u, node_param, params):
            return x[idx, state_slice]

        return state_node_fn

    if isinstance(base, Control):
        control_slice = base._slice
        if control_slice is None:
            raise ValueError(f"Control {base.name!r} has no slice assigned")

        def control_node_fn(x, u, node_param, params):
            return u[idx, control_slice]

        return control_node_fn

    # Compound base expression: lower once, then evaluate at the fixed node.
    base_fn = self.lower(base)

    def compound_node_fn(x, u, node_param, params):
        x_row = x if len(x.shape) <= 1 else x[idx]
        u_row = u if len(u.shape) <= 1 else u[idx]
        return base_fn(x_row, u_row, idx, params)

    return compound_node_fn
1337
+
1338
+ @visitor(CrossNodeConstraint)
1339
def _visit_cross_node_constraint(self, node: CrossNodeConstraint):
    """Lower a CrossNodeConstraint to a trajectory-level function.

    A CrossNodeConstraint wraps a constraint whose terms reference several
    trajectory nodes through NodeReference, e.g. the rate limit
    ``x.at(k) - x.at(k-1) <= r``. Ordinary nodal constraints use the
    signature (x, u, node, params) and are evaluated per node. A cross-node
    constraint instead receives the full trajectory arrays.

    The wrapped constraint is lowered with the usual machinery, producing an
    (x, u, node, params) function. It is then adapted to the (X, U, params)
    signature by passing a constant ``0`` as the node argument. That argument
    has no effect on the NodeReference leaves, which carry their own fixed
    indices.

    Args:
        node: CrossNodeConstraint expression wrapping the inner constraint.

    Returns:
        Function (X, U, params) -> constraint residual (g <= 0 convention),
        where X is the full state trajectory (N, n_x), U the full control
        trajectory (N, n_u), and params the problem parameter dictionary.

    Example:
        For ``position.at(5) - position.at(4) <= max_step`` the lowered
        function evaluates ``X[5, pos_slice] - X[4, pos_slice] - max_step``.
    """
    lowered_constraint = self.lower(node.constraint)

    def trajectory_constraint(X, U, params):
        # Node argument is a placeholder; NodeReferences index X/U themselves.
        return lowered_constraint(X, U, 0, params)

    return trajectory_constraint