openscvx 0.3.2.dev170__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of openscvx might be problematic.
- openscvx/__init__.py +123 -0
- openscvx/_version.py +34 -0
- openscvx/algorithms/__init__.py +92 -0
- openscvx/algorithms/autotuning.py +24 -0
- openscvx/algorithms/base.py +351 -0
- openscvx/algorithms/optimization_results.py +215 -0
- openscvx/algorithms/penalized_trust_region.py +384 -0
- openscvx/config.py +437 -0
- openscvx/discretization/__init__.py +47 -0
- openscvx/discretization/discretization.py +236 -0
- openscvx/expert/__init__.py +23 -0
- openscvx/expert/byof.py +326 -0
- openscvx/expert/lowering.py +419 -0
- openscvx/expert/validation.py +357 -0
- openscvx/integrators/__init__.py +48 -0
- openscvx/integrators/runge_kutta.py +281 -0
- openscvx/lowered/__init__.py +30 -0
- openscvx/lowered/cvxpy_constraints.py +23 -0
- openscvx/lowered/cvxpy_variables.py +124 -0
- openscvx/lowered/dynamics.py +34 -0
- openscvx/lowered/jax_constraints.py +133 -0
- openscvx/lowered/parameters.py +54 -0
- openscvx/lowered/problem.py +70 -0
- openscvx/lowered/unified.py +718 -0
- openscvx/plotting/__init__.py +63 -0
- openscvx/plotting/plotting.py +756 -0
- openscvx/plotting/scp_iteration.py +299 -0
- openscvx/plotting/viser/__init__.py +126 -0
- openscvx/plotting/viser/animated.py +605 -0
- openscvx/plotting/viser/plotly_integration.py +333 -0
- openscvx/plotting/viser/primitives.py +355 -0
- openscvx/plotting/viser/scp.py +459 -0
- openscvx/plotting/viser/server.py +112 -0
- openscvx/problem.py +734 -0
- openscvx/propagation/__init__.py +60 -0
- openscvx/propagation/post_processing.py +104 -0
- openscvx/propagation/propagation.py +248 -0
- openscvx/solvers/__init__.py +51 -0
- openscvx/solvers/cvxpy.py +226 -0
- openscvx/symbolic/__init__.py +9 -0
- openscvx/symbolic/augmentation.py +630 -0
- openscvx/symbolic/builder.py +492 -0
- openscvx/symbolic/constraint_set.py +92 -0
- openscvx/symbolic/expr/__init__.py +222 -0
- openscvx/symbolic/expr/arithmetic.py +517 -0
- openscvx/symbolic/expr/array.py +632 -0
- openscvx/symbolic/expr/constraint.py +796 -0
- openscvx/symbolic/expr/control.py +135 -0
- openscvx/symbolic/expr/expr.py +720 -0
- openscvx/symbolic/expr/lie/__init__.py +87 -0
- openscvx/symbolic/expr/lie/adjoint.py +357 -0
- openscvx/symbolic/expr/lie/se3.py +172 -0
- openscvx/symbolic/expr/lie/so3.py +138 -0
- openscvx/symbolic/expr/linalg.py +279 -0
- openscvx/symbolic/expr/math.py +699 -0
- openscvx/symbolic/expr/spatial.py +209 -0
- openscvx/symbolic/expr/state.py +607 -0
- openscvx/symbolic/expr/stl.py +136 -0
- openscvx/symbolic/expr/variable.py +321 -0
- openscvx/symbolic/hashing.py +112 -0
- openscvx/symbolic/lower.py +760 -0
- openscvx/symbolic/lowerers/__init__.py +106 -0
- openscvx/symbolic/lowerers/cvxpy.py +1302 -0
- openscvx/symbolic/lowerers/jax.py +1382 -0
- openscvx/symbolic/preprocessing.py +757 -0
- openscvx/symbolic/problem.py +110 -0
- openscvx/symbolic/time.py +116 -0
- openscvx/symbolic/unified.py +420 -0
- openscvx/utils/__init__.py +20 -0
- openscvx/utils/cache.py +131 -0
- openscvx/utils/caching.py +210 -0
- openscvx/utils/printing.py +301 -0
- openscvx/utils/profiling.py +37 -0
- openscvx/utils/utils.py +100 -0
- openscvx-0.3.2.dev170.dist-info/METADATA +350 -0
- openscvx-0.3.2.dev170.dist-info/RECORD +79 -0
- openscvx-0.3.2.dev170.dist-info/WHEEL +5 -0
- openscvx-0.3.2.dev170.dist-info/licenses/LICENSE +201 -0
- openscvx-0.3.2.dev170.dist-info/top_level.txt +1 -0
openscvx/symbolic/lowerers/jax.py
@@ -0,0 +1,1382 @@
"""JAX backend for lowering symbolic expressions to executable functions.

This module implements the JAX lowering backend that converts symbolic expression
AST nodes into JAX functions with automatic differentiation support. The lowering
uses a visitor pattern where each expression type has a corresponding visitor method.

Architecture:
    The JAX lowerer follows a visitor pattern with centralized registration:

    1. **Visitor Registration**: The @visitor decorator registers handler functions
       for each expression type in the _JAX_VISITORS dictionary
    2. **Dispatch**: The dispatch() function looks up and calls the appropriate
       visitor based on the expression's type
    3. **Recursive Lowering**: Each visitor recursively lowers child expressions
       and composes JAX operations
    4. **Standardized Signature**: All lowered functions have signature
       (x, u, node, params) -> result for uniformity

Key Features:
    - **Automatic Differentiation**: Lowered functions can be differentiated using
      JAX's jacfwd/jacrev for computing Jacobians
    - **JIT Compilation**: All functions are JAX-traceable and JIT-compatible
    - **Functional Closures**: Each visitor returns a closure that captures
      necessary constants and child functions
    - **Broadcasting**: Supports NumPy-style broadcasting through jnp operations

Lowered Function Signature:
    All lowered functions have a uniform signature::

        f(x, u, node, params) -> result

    Where:

    - x: State vector (jnp.ndarray)
    - u: Control vector (jnp.ndarray)
    - node: Node index for time-varying behavior (scalar or array)
    - params: Dictionary of parameter values (dict[str, Any])
    - result: JAX array (scalar, vector, or matrix)

Example:
    Basic usage::

        from openscvx.symbolic.lowerers.jax import JaxLowerer
        import openscvx as ox

        # Create symbolic expression
        x = ox.State("x", shape=(3,))
        u = ox.Control("u", shape=(2,))
        expr = ox.Norm(x)**2 + 0.1 * ox.Norm(u)**2

        # Lower to JAX
        lowerer = JaxLowerer()
        f = lowerer.lower(expr)

        # Evaluate
        import jax.numpy as jnp
        x_val = jnp.array([1.0, 2.0, 3.0])
        u_val = jnp.array([0.5, 0.5])
        result = f(x_val, u_val, node=0, params={})

        # Differentiate
        from jax import jacfwd
        df_dx = jacfwd(f, argnums=0)
        gradient = df_dx(x_val, u_val, node=0, params={})

For Contributors:
    **Adding Support for New Expression Types**

    To add support for a new symbolic expression type to JAX lowering:

    1. **Define the visitor method** in JaxLowerer with the @visitor decorator::

        @visitor(MyNewExpr)
        def _visit_my_new_expr(self, node: MyNewExpr):
            # Lower child expressions recursively
            operand_fn = self.lower(node.operand)

            # Return a closure with signature (x, u, node, params) -> result
            return lambda x, u, node, params: jnp.my_operation(
                operand_fn(x, u, node, params)
            )

    2. **Key requirements**:
        - Use the @visitor(ExprType) decorator to register the handler
        - Method name should be _visit_<expr_name> (private, lowercase, snake_case)
        - Recursively lower all child expressions using self.lower()
        - Return a closure with signature (x, u, node, params) -> jax_array
        - Use jnp.* operations (not np.*) for JAX traceability
        - Ensure the result is JAX-differentiable (avoid Python control flow)

    3. **Example patterns**:
        - Unary operation: Lower operand, apply jnp function
        - Binary operation: Lower both operands, combine with jnp operation
        - N-ary operation: Lower all operands, reduce or combine them
        - Conditional logic: Use jax.lax.cond for branching (see _visit_ctcs)

    4. **Testing**: Ensure your visitor works with:
        - JAX JIT compilation: jax.jit(lowered_fn)
        - Automatic differentiation: jax.jacfwd(lowered_fn, argnums=0)
        - Vectorization: jax.vmap(lowered_fn)

See Also:
    - lower_to_jax(): Convenience wrapper in symbolic/lower.py
    - CVXPyLowerer: Alternative backend for convex constraints
    - dispatch(): Core dispatch function for visitor pattern
"""

from typing import Any, Callable, Dict, Type

import jax.numpy as jnp
from jax.lax import cond
from jax.scipy.special import logsumexp

from openscvx.symbolic.expr import (
    CTCS,
    QDCM,
    SSM,
    SSMP,
    Abs,
    Add,
    Adjoint,
    AdjointDual,
    Block,
    Concat,
    Constant,
    Constraint,
    Cos,
    CrossNodeConstraint,
    Diag,
    Div,
    Equality,
    Exp,
    Expr,
    Hstack,
    Huber,
    Index,
    Inequality,
    Log,
    LogSumExp,
    MatMul,
    Max,
    Mul,
    Neg,
    NodalConstraint,
    NodeReference,
    Norm,
    Or,
    Parameter,
    PositivePart,
    Power,
    SE3Adjoint,
    SE3AdjointDual,
    Sin,
    SmoothReLU,
    Sqrt,
    Square,
    Stack,
    Sub,
    Sum,
    Tan,
    Transpose,
    Vstack,
)
from openscvx.symbolic.expr.control import Control
from openscvx.symbolic.expr.lie import (
    SE3Exp,
    SE3Log,
    SO3Exp,
    SO3Log,
)
from openscvx.symbolic.expr.state import State

_JAX_VISITORS: Dict[Type[Expr], Callable] = {}
"""Registry mapping expression types to their visitor functions."""


def visitor(expr_cls: Type[Expr]):
    """Decorator to register a visitor function for an expression type.

    This decorator registers a visitor method to handle a specific expression
    type during JAX lowering. The decorated function is stored in _JAX_VISITORS
    and will be called by dispatch() when lowering that expression type.

    Args:
        expr_cls: The Expr subclass this visitor handles (e.g., Add, Mul, Norm)

    Returns:
        Decorator function that registers the visitor and returns it unchanged

    Example:
        Register a visitor function for the Add expression::

            @visitor(Add)
            def _visit_add(self, node: Add):
                # Lower addition to JAX
                ...

    Note:
        Multiple expression types can share a visitor by stacking decorators::

            @visitor(Equality)
            @visitor(Inequality)
            def _visit_constraint(self, node: Constraint):
                # Handle both equality and inequality
                ...
    """

    def register(fn: Callable[[Any, Expr], Callable]):
        _JAX_VISITORS[expr_cls] = fn
        return fn

    return register


def dispatch(lowerer: Any, expr: Expr):
    """Dispatch an expression to its registered visitor function.

    Looks up the visitor function for the expression's type and calls it.
    This is the core of the visitor pattern implementation.

    Args:
        lowerer: The JaxLowerer instance (provides context for visitor methods)
        expr: The expression node to lower

    Returns:
        The result of calling the visitor function (typically a JAX callable)

    Raises:
        NotImplementedError: If no visitor is registered for the expression type

    Example:
        Dispatch an expression to lower it to a JAX function::

            lowerer = JaxLowerer()
            expr = Add(x, y)
            fn = dispatch(lowerer, expr)  # Calls visit_add
    """
    fn = _JAX_VISITORS.get(type(expr))
    if fn is None:
        raise NotImplementedError(
            f"{lowerer.__class__.__name__!r} has no visitor for {type(expr).__name__}"
        )
    return fn(lowerer, expr)
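
# ---------------------------------------------------------------------------
# [Editor's note] Illustrative sketch, not part of jax.py: the registration /
# dispatch pattern above, reduced to a self-contained toy (all names here are
# hypothetical) so the control flow is easy to trace.
_TOY_VISITORS = {}

def toy_visitor(cls):
    def register(fn):
        _TOY_VISITORS[cls] = fn          # record the handler for this node type
        return fn
    return register

class ToyNeg:
    def __init__(self, value):
        self.value = value

@toy_visitor(ToyNeg)
def _visit_toy_neg(lowerer, node):
    # Returns a closure, mirroring the (x, u, node, params) signature above
    return lambda x, u, n, params: -node.value

def toy_dispatch(lowerer, expr):
    fn = _TOY_VISITORS.get(type(expr))
    if fn is None:
        raise NotImplementedError(f"no visitor for {type(expr).__name__}")
    return fn(lowerer, expr)

print(toy_dispatch(None, ToyNeg(3.0))(None, None, 0, {}))  # -> -3.0
# ---------------------------------------------------------------------------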


class JaxLowerer:
    """JAX backend for lowering symbolic expressions to executable functions.

    This class implements the visitor pattern for converting symbolic expression
    AST nodes to JAX functions. Each expression type has a corresponding visitor
    method decorated with @visitor that handles the lowering logic.

    The lowering process is recursive: each visitor lowers its child expressions
    first, then composes them into a JAX operation. All lowered functions have
    a standardized signature (x, u, node, params) -> result.

    Attributes:
        None (stateless lowerer - all state is in the expression tree)

    Example:
        Set up the JaxLowerer and lower an expression to a JAX function::

            lowerer = JaxLowerer()
            expr = ox.Norm(x)**2 + 0.1 * ox.Norm(u)**2
            f = lowerer.lower(expr)
            result = f(x_val, u_val, node=0, params={})

    Note:
        The lowerer is stateless and can be reused for multiple expressions.
        All visitor methods are instance methods to maintain a clean interface,
        but they don't modify instance state.
    """

    def lower(self, expr: Expr):
        """Lower a symbolic expression to a JAX function.

        Main entry point for lowering. Delegates to dispatch() which looks up
        the appropriate visitor method based on the expression type.

        Args:
            expr: Symbolic expression to lower (any Expr subclass)

        Returns:
            JAX function with signature (x, u, node, params) -> result

        Raises:
            NotImplementedError: If no visitor exists for the expression type
            ValueError: If the expression is malformed (e.g., State without slice)

        Example:
            Lower an expression to a JAX function::

                lowerer = JaxLowerer()
                x = ox.State("x", shape=(3,))
                expr = ox.Norm(x)
                f = lowerer.lower(expr)
                # f is now callable
        """
        return dispatch(self, expr)

    @visitor(Constant)
    def _visit_constant(self, node: Constant):
        """Lower a constant value to a JAX function.

        Captures the constant value and returns a function that always returns it.
        Scalar constants are squeezed to ensure they're true scalars, not (1,) arrays.

        Args:
            node: Constant expression node

        Returns:
            Function (x, u, node, params) -> constant_value
        """
        # capture the constant value once
        value = jnp.array(node.value)
        # For scalar constants (single element arrays), squeeze to scalar
        # This prevents (1,) shapes in constraint residuals
        if value.size == 1:
            value = value.squeeze()
        return lambda x, u, node, params: value

    @visitor(State)
    def _visit_state(self, node: State):
        """Lower a state variable to a JAX function.

        Extracts the appropriate slice from the unified state vector x using
        the slice assigned during unification.

        Args:
            node: State expression node

        Returns:
            Function (x, u, node, params) -> x[slice]

        Raises:
            ValueError: If the state has no slice assigned (unification not run)
        """
        sl = node._slice
        if sl is None:
            raise ValueError(f"State {node.name!r} has no slice assigned")
        return lambda x, u, node, params: x[sl]

    @visitor(Control)
    def _visit_control(self, node: Control):
        """Lower a control variable to a JAX function.

        Extracts the appropriate slice from the unified control vector u using
        the slice assigned during unification.

        Args:
            node: Control expression node

        Returns:
            Function (x, u, node, params) -> u[slice]

        Raises:
            ValueError: If the control has no slice assigned (unification not run)
        """
        sl = node._slice
        if sl is None:
            raise ValueError(f"Control {node.name!r} has no slice assigned")
        return lambda x, u, node, params: u[sl]

    @visitor(Parameter)
    def _visit_parameter(self, node: Parameter):
        """Lower a parameter to a JAX function.

        Parameters are looked up by name in the params dictionary at evaluation time,
        allowing runtime parameter updates without recompilation.

        Args:
            node: Parameter expression node

        Returns:
            Function (x, u, node, params) -> params[name]
        """
        param_name = node.name
        return lambda x, u, node, params: jnp.array(params[param_name])

    @visitor(Add)
    def _visit_add(self, node: Add):
        """Lower addition to JAX function.

        Recursively lowers all terms and composes them with element-wise addition.
        Supports broadcasting following NumPy/JAX rules.

        Args:
            node: Add expression node with multiple terms

        Returns:
            Function (x, u, node, params) -> sum of all terms
        """
        fs = [self.lower(term) for term in node.terms]

        def fn(x, u, node, params):
            acc = fs[0](x, u, node, params)
            for f in fs[1:]:
                acc = acc + f(x, u, node, params)
            return acc

        return fn

    @visitor(Sub)
    def _visit_sub(self, node: Sub):
        """Lower subtraction to JAX function (element-wise left - right)."""
        fL = self.lower(node.left)
        fR = self.lower(node.right)
        return lambda x, u, node, params: fL(x, u, node, params) - fR(x, u, node, params)

    @visitor(Mul)
    def _visit_mul(self, node: Mul):
        """Lower element-wise multiplication to JAX function (Hadamard product)."""
        fs = [self.lower(factor) for factor in node.factors]

        def fn(x, u, node, params):
            acc = fs[0](x, u, node, params)
            for f in fs[1:]:
                acc = acc * f(x, u, node, params)
            return acc

        return fn

    @visitor(Div)
    def _visit_div(self, node: Div):
        """Lower element-wise division to JAX function."""
        fL = self.lower(node.left)
        fR = self.lower(node.right)
        return lambda x, u, node, params: fL(x, u, node, params) / fR(x, u, node, params)

    @visitor(MatMul)
    def _visit_matmul(self, node: MatMul):
        """Lower matrix multiplication to JAX function using jnp.matmul."""
        fL = self.lower(node.left)
        fR = self.lower(node.right)
        return lambda x, u, node, params: jnp.matmul(fL(x, u, node, params), fR(x, u, node, params))

    @visitor(Neg)
    def _visit_neg(self, node: Neg):
        """Lower negation (unary minus) to JAX function."""
        fO = self.lower(node.operand)
        return lambda x, u, node, params: -fO(x, u, node, params)

    @visitor(Sum)
    def _visit_sum(self, node: Sum):
        """Lower sum reduction to JAX function (sums all elements)."""
        f = self.lower(node.operand)
        return lambda x, u, node, params: jnp.sum(f(x, u, node, params))

    @visitor(Norm)
    def _visit_norm(self, node: Norm):
        """Lower norm operation to JAX function.

        Converts symbolic norm to jnp.linalg.norm with appropriate ord parameter.
        Handles string ord values like "inf", "-inf", "fro".

        Args:
            node: Norm expression node with ord attribute

        Returns:
            Function (x, u, node, params) -> norm of operand
        """
        f = self.lower(node.operand)
        ord_val = node.ord

        # Convert string ord values to appropriate JAX values
        if ord_val == "inf":
            ord_val = jnp.inf
        elif ord_val == "-inf":
            ord_val = -jnp.inf
        elif ord_val == "fro":
            # For vectors, Frobenius norm is the same as 2-norm
            ord_val = None  # Default is 2-norm

        return lambda x, u, node, params: jnp.linalg.norm(f(x, u, node, params), ord=ord_val)

    @visitor(Index)
    def _visit_index(self, node: Index):
        """Lower indexing/slicing operation to JAX function.

        For multi-dimensional indexing, the base array is reshaped to its
        original shape before applying the index. This is necessary because
        State variables are stored flattened in the state vector.
        """
        f_base = self.lower(node.base)
        idx = node.index
        base_shape = node.base.check_shape()

        def index_fn(x, u, node_arg, params):
            arr = f_base(x, u, node_arg, params)
            # Reshape to original shape for multi-dimensional indexing
            if len(base_shape) > 1:
                arr = arr.reshape(base_shape)
            else:
                arr = jnp.atleast_1d(arr)
            return arr[idx]

        return index_fn

    @visitor(Concat)
    def _visit_concat(self, node: Concat):
        """Lower concatenation to JAX function (concatenates along axis 0)."""
        # lower each child
        fn_list = [self.lower(child) for child in node.exprs]

        # wrapper that promotes scalars to 1-D and concatenates
        def concat_fn(x, u, node, params):
            parts = [jnp.atleast_1d(fn(x, u, node, params)) for fn in fn_list]
            return jnp.concatenate(parts, axis=0)

        return concat_fn

    @visitor(Sin)
    def _visit_sin(self, node: Sin):
        """Lower sine function to JAX function."""
        fO = self.lower(node.operand)
        return lambda x, u, node, params: jnp.sin(fO(x, u, node, params))

    @visitor(Cos)
    def _visit_cos(self, node: Cos):
        """Lower cosine function to JAX function."""
        fO = self.lower(node.operand)
        return lambda x, u, node, params: jnp.cos(fO(x, u, node, params))

    @visitor(Tan)
    def _visit_tan(self, node: Tan):
        """Lower tangent function to JAX function."""
        fO = self.lower(node.operand)
        return lambda x, u, node, params: jnp.tan(fO(x, u, node, params))

    @visitor(Exp)
    def _visit_exp(self, node: Exp):
        """Lower exponential function to JAX function."""
        fO = self.lower(node.operand)
        return lambda x, u, node, params: jnp.exp(fO(x, u, node, params))

    @visitor(Log)
    def _visit_log(self, node: Log):
        """Lower natural logarithm to JAX function."""
        fO = self.lower(node.operand)
        return lambda x, u, node, params: jnp.log(fO(x, u, node, params))

    @visitor(Abs)
    def _visit_abs(self, node: Abs):
        """Lower absolute value to JAX function."""
        fO = self.lower(node.operand)
        return lambda x, u, node, params: jnp.abs(fO(x, u, node, params))

    @visitor(Equality)
    @visitor(Inequality)
    def _visit_constraint(self, node: Constraint):
        """Lower constraint to residual function.

        Both equality (lhs == rhs) and inequality (lhs <= rhs) constraints are
        lowered to their residual form: lhs - rhs. The constraint is satisfied
        when the residual equals zero (equality) or is non-positive (inequality).

        Args:
            node: Equality or Inequality constraint node

        Returns:
            Function (x, u, node, params) -> lhs - rhs (constraint residual)

        Note:
            The returned residual is used in penalty methods and Lagrangian terms.
            For equality: residual should be 0
            For inequality: residual should be <= 0
        """
        fL = self.lower(node.lhs)
        fR = self.lower(node.rhs)
        return lambda x, u, node, params: fL(x, u, node, params) - fR(x, u, node, params)
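
# ---------------------------------------------------------------------------
# [Editor's note] Illustrative sketch, not part of jax.py: what the residual
# convention above means numerically. A lowered `Norm(x) <= 5.0` behaves like
# the hand-written closure below (names and values are made up).
import jax.numpy as jnp

residual_fn = lambda x, u, node, params: jnp.linalg.norm(x) - 5.0

print(residual_fn(jnp.array([3.0, 4.0]), None, 0, {}))   # 0.0 -> on the boundary
print(residual_fn(jnp.array([1.0, 1.0]), None, 0, {}))   # < 0 -> satisfied
print(residual_fn(jnp.array([6.0, 8.0]), None, 0, {}))   # > 0 -> violated
# ---------------------------------------------------------------------------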

    # TODO: (norrisg) CTCS is playing 2 roles here: both as a constraint wrapper and as the penalty
    # expression w/ conditional logic. Consider adding conditional logic as separate AST nodes.
    # Then, CTCS remains a wrapper and we just wrap the penalty expression with the conditional
    # logic when we lower it.
    @visitor(CTCS)
    def _visit_ctcs(self, node: CTCS):
        """Lower CTCS (Continuous-Time Constraint Satisfaction) to JAX function.

        CTCS constraints use penalty methods to enforce constraints over continuous
        time intervals. The lowered function includes conditional logic to activate
        the penalty only within the specified node interval.

        Args:
            node: CTCS constraint node with penalty expression and optional node range

        Returns:
            Function (x, u, current_node, params) -> penalty value or 0

        Note:
            Uses jax.lax.cond for JAX-traceable conditional evaluation. The penalty
            is active only when current_node is in [start_node, end_node).
            If no node range is specified, the penalty is always active.

        See Also:
            - CTCS: The symbolic CTCS constraint class
            - penalty functions: PositivePart, Huber, SmoothReLU
        """
        # Lower the penalty expression (which includes the constraint residual)
        penalty_expr_fn = self.lower(node.penalty_expr())

        def ctcs_fn(x, u, current_node, params):
            # Check if constraint is active at this node
            if node.nodes is not None:
                start_node, end_node = node.nodes
                # Extract scalar value from current_node (which may be array or scalar)
                # Keep as JAX array for tracing compatibility
                node_scalar = jnp.atleast_1d(current_node)[0]
                is_active = (start_node <= node_scalar) & (node_scalar < end_node)

                # Use jax.lax.cond for conditional evaluation
                return cond(
                    is_active,
                    lambda _: penalty_expr_fn(x, u, current_node, params),
                    lambda _: 0.0,
                    operand=None,
                )
            else:
                # Always active if no node range specified
                return penalty_expr_fn(x, u, current_node, params)

        return ctcs_fn
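
# ---------------------------------------------------------------------------
# [Editor's note] Illustrative sketch, not part of jax.py: the node-window
# gating used by _visit_ctcs, written out with a toy penalty so the lax.cond
# branch structure is visible (the penalty, values, and window are made up).
import jax
import jax.numpy as jnp

def gated_penalty(x, current_node, start_node=2, end_node=5):
    node_scalar = jnp.atleast_1d(current_node)[0]
    is_active = (start_node <= node_scalar) & (node_scalar < end_node)
    # Toy penalty: squared positive part of (||x|| - 1)
    penalty = lambda _: jnp.maximum(jnp.linalg.norm(x) - 1.0, 0.0) ** 2
    return jax.lax.cond(is_active, penalty, lambda _: 0.0, None)

x = jnp.array([2.0, 0.0])
print(gated_penalty(x, 3))   # inside [2, 5): penalty = 1.0
print(gated_penalty(x, 7))   # outside the window: 0.0
# ---------------------------------------------------------------------------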

    @visitor(PositivePart)
    def _visit_pos(self, node):
        """Lower positive part function to JAX.

        Computes max(x, 0), used in penalty methods for inequality constraints.

        Args:
            node: PositivePart expression node

        Returns:
            Function (x, u, node, params) -> max(operand, 0)
        """
        f = self.lower(node.x)
        return lambda x, u, node, params: jnp.maximum(f(x, u, node, params), 0.0)

    @visitor(Square)
    def _visit_square(self, node):
        """Lower square function to JAX.

        Computes x^2 element-wise. Used in quadratic penalty methods.

        Args:
            node: Square expression node

        Returns:
            Function (x, u, node, params) -> operand^2
        """
        f = self.lower(node.x)
        return lambda x, u, node, params: f(x, u, node, params) * f(x, u, node, params)

    @visitor(Huber)
    def _visit_huber(self, node):
        """Lower Huber penalty function to JAX.

        Huber penalty is quadratic for small values and linear for large values:
        - |x| <= delta: 0.5 * x^2
        - |x| > delta: delta * (|x| - 0.5 * delta)

        Args:
            node: Huber expression node with delta parameter

        Returns:
            Function (x, u, node, params) -> Huber penalty
        """
        f = self.lower(node.x)
        delta = node.delta
        return lambda x, u, node, params: jnp.where(
            jnp.abs(f(x, u, node, params)) <= delta,
            0.5 * f(x, u, node, params) ** 2,
            delta * (jnp.abs(f(x, u, node, params)) - 0.5 * delta),
        )

    @visitor(SmoothReLU)
    def _visit_srelu(self, node):
        """Lower smooth ReLU penalty function to JAX.

        Smooth approximation to ReLU: sqrt(max(x, 0)^2 + c^2) - c
        Differentiable everywhere, approaches ReLU as c -> 0.

        Args:
            node: SmoothReLU expression node with smoothing parameter c

        Returns:
            Function (x, u, node, params) -> smooth ReLU penalty
        """
        f = self.lower(node.x)
        c = node.c
        # smooth_relu(pos(x)) = sqrt(pos(x)^2 + c^2) - c ; here f already includes pos inside node
        return (
            lambda x, u, node, params: jnp.sqrt(jnp.maximum(f(x, u, node, params), 0.0) ** 2 + c**2)
            - c
        )
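
# ---------------------------------------------------------------------------
# [Editor's note] Illustrative sketch, not part of jax.py: the three penalty
# shapes above evaluated on the same residuals, using the formulas from the
# docstrings (the delta and c values are made up).
import jax.numpy as jnp

r = jnp.array([-1.0, 0.5, 3.0])          # constraint residuals
delta, c = 1.0, 0.1

pos = jnp.maximum(r, 0.0)                                           # [0.0, 0.5, 3.0]
huber = jnp.where(jnp.abs(r) <= delta,
                  0.5 * r**2,
                  delta * (jnp.abs(r) - 0.5 * delta))               # [0.5, 0.125, 2.5]
smooth_relu = jnp.sqrt(jnp.maximum(r, 0.0) ** 2 + c**2) - c         # ~[0.0, 0.41, 2.9]

print(pos, huber, smooth_relu)
# ---------------------------------------------------------------------------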

    @visitor(NodalConstraint)
    def _visit_nodal_constraint(self, node: NodalConstraint):
        """Lower a NodalConstraint by lowering its underlying constraint.

        NodalConstraint is a wrapper that specifies which nodes a constraint
        applies to. The lowering just unwraps and lowers the inner constraint.

        Args:
            node: NodalConstraint wrapper

        Returns:
            Function from lowering the wrapped constraint expression
        """
        return self.lower(node.constraint)

    @visitor(Sqrt)
    def _visit_sqrt(self, node: Sqrt):
        """Lower square root to JAX function."""
        f = self.lower(node.operand)
        return lambda x, u, node, params: jnp.sqrt(f(x, u, node, params))

    @visitor(Max)
    def _visit_max(self, node: Max):
        """Lower element-wise maximum to JAX function."""
        fs = [self.lower(op) for op in node.operands]

        def fn(x, u, node, params):
            values = [f(x, u, node, params) for f in fs]
            # jnp.maximum is binary, so fold the operands pairwise
            result = values[0]
            for val in values[1:]:
                result = jnp.maximum(result, val)
            return result

        return fn

    @visitor(LogSumExp)
    def _visit_logsumexp(self, node: LogSumExp):
        """Lower log-sum-exp to JAX function.

        Computes log(sum(exp(x_i))) for multiple operands, which is a smooth
        approximation to the maximum function. Uses JAX's numerically stable
        logsumexp implementation. Performs element-wise log-sum-exp with
        broadcasting support.
        """
        fs = [self.lower(op) for op in node.operands]

        def fn(x, u, node, params):
            values = [f(x, u, node, params) for f in fs]
            # Broadcast all values to the same shape, then stack along new axis
            # and compute logsumexp along that axis for element-wise operation
            broadcasted = jnp.broadcast_arrays(*values)
            stacked = jnp.stack(list(broadcasted), axis=0)
            return logsumexp(stacked, axis=0)

        return fn
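
# ---------------------------------------------------------------------------
# [Editor's note] Illustrative sketch, not part of jax.py: log-sum-exp as a
# smooth upper bound on the element-wise max, mirroring the stacking done in
# _visit_logsumexp (operand values are made up).
import jax.numpy as jnp
from jax.scipy.special import logsumexp

a = jnp.array([1.0, 4.0])
b = jnp.array([3.0, 2.0])
stacked = jnp.stack(jnp.broadcast_arrays(a, b), axis=0)

print(jnp.maximum(a, b))            # [3. 4.]  hard max
print(logsumexp(stacked, axis=0))   # [~3.13, ~4.13]  smooth, always >= hard max
# ---------------------------------------------------------------------------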

    @visitor(Transpose)
    def _visit_transpose(self, node: Transpose):
        """Lower matrix transpose to JAX function."""
        f = self.lower(node.operand)
        return lambda x, u, node, params: jnp.transpose(f(x, u, node, params))

    @visitor(Power)
    def _visit_power(self, node: Power):
        """Lower element-wise power (base**exponent) to JAX function."""
        fB = self.lower(node.base)
        fE = self.lower(node.exponent)
        return lambda x, u, node, params: jnp.power(fB(x, u, node, params), fE(x, u, node, params))

    @visitor(Stack)
    def _visit_stack(self, node: Stack):
        """Lower vertical stacking to JAX function (stack along axis 0)."""
        row_fns = [self.lower(row) for row in node.rows]

        def stack_fn(x, u, node, params):
            rows = [jnp.atleast_1d(fn(x, u, node, params)) for fn in row_fns]
            return jnp.stack(rows, axis=0)

        return stack_fn

    @visitor(Hstack)
    def _visit_hstack(self, node: Hstack):
        """Lower horizontal stacking to JAX function."""
        array_fns = [self.lower(arr) for arr in node.arrays]

        def hstack_fn(x, u, node, params):
            arrays = [jnp.atleast_1d(fn(x, u, node, params)) for fn in array_fns]
            return jnp.hstack(arrays)

        return hstack_fn

    @visitor(Vstack)
    def _visit_vstack(self, node: Vstack):
        """Lower vertical stacking to JAX function."""
        array_fns = [self.lower(arr) for arr in node.arrays]

        def vstack_fn(x, u, node, params):
            arrays = [jnp.atleast_1d(fn(x, u, node, params)) for fn in array_fns]
            return jnp.vstack(arrays)

        return vstack_fn

    @visitor(Block)
    def _visit_block(self, node: Block):
        """Lower block matrix construction to JAX function.

        Assembles a block matrix from nested lists of expressions. For 2D blocks,
        uses jnp.block directly. For N-D blocks (3D+), manually assembles along
        the first two dimensions using concatenate, since jnp.block concatenates
        along the last axes (not what we want for block matrix semantics).

        Args:
            node: Block expression node with 2D nested structure of expressions

        Returns:
            Function (x, u, node, params) -> assembled block matrix/tensor
        """
        # Lower each block expression
        block_fns = [[self.lower(block) for block in row] for row in node.blocks]

        def block_fn(x, u, node_arg, params):
            # Evaluate all blocks
            block_values = [
                [jnp.atleast_1d(fn(x, u, node_arg, params)) for fn in row] for row in block_fns
            ]

            # Check if any block is 3D+ (need manual assembly)
            max_ndim = max(arr.ndim for row in block_values for arr in row)

            if max_ndim <= 2:
                # For 2D, jnp.block works correctly
                return jnp.block(block_values)
            else:
                # For N-D, manually assemble along axes 0 and 1
                # First, ensure all blocks have the same number of dimensions
                def promote_to_ndim(arr, target_ndim):
                    while arr.ndim < target_ndim:
                        arr = jnp.expand_dims(arr, axis=0)
                    return arr

                block_values = [
                    [promote_to_ndim(arr, max_ndim) for arr in row] for row in block_values
                ]

                # Concatenate each row along axis 1 (horizontal)
                row_results = [jnp.concatenate(row, axis=1) for row in block_values]
                # Concatenate rows along axis 0 (vertical)
                return jnp.concatenate(row_results, axis=0)

        return block_fn

    @visitor(QDCM)
    def _visit_qdcm(self, node: QDCM):
        """Lower quaternion to direction cosine matrix (DCM) conversion.

        Converts a unit quaternion [q0, q1, q2, q3] to a 3x3 rotation matrix.
        Used in 6-DOF spacecraft and robotics applications.

        The quaternion is normalized before conversion to ensure a valid rotation
        matrix. The DCM is computed using the standard quaternion-to-DCM formula.

        Args:
            node: QDCM expression node

        Returns:
            Function (x, u, node, params) -> 3x3 rotation matrix

        Note:
            Quaternion convention: [w, x, y, z] where w is the scalar part
        """
        f = self.lower(node.q)

        def qdcm_fn(x, u, node, params):
            q = f(x, u, node, params)
            # Normalize the quaternion
            q_norm = jnp.sqrt(q[0] ** 2 + q[1] ** 2 + q[2] ** 2 + q[3] ** 2)
            w, qx, qy, qz = q / q_norm
            # Convert to direction cosine matrix
            return jnp.array(
                [
                    [1 - 2 * (qy**2 + qz**2), 2 * (qx * qy - qz * w), 2 * (qx * qz + qy * w)],
                    [2 * (qx * qy + qz * w), 1 - 2 * (qx**2 + qz**2), 2 * (qy * qz - qx * w)],
                    [2 * (qx * qz - qy * w), 2 * (qy * qz + qx * w), 1 - 2 * (qx**2 + qy**2)],
                ]
            )

        return qdcm_fn
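
# ---------------------------------------------------------------------------
# [Editor's note] Illustrative sketch, not part of jax.py: sanity checks for
# the quaternion-to-DCM formula above ([w, x, y, z] convention), written as a
# standalone helper with made-up inputs.
import jax.numpy as jnp

def qdcm(q):
    w, qx, qy, qz = q / jnp.linalg.norm(q)
    return jnp.array([
        [1 - 2 * (qy**2 + qz**2), 2 * (qx * qy - qz * w), 2 * (qx * qz + qy * w)],
        [2 * (qx * qy + qz * w), 1 - 2 * (qx**2 + qz**2), 2 * (qy * qz - qx * w)],
        [2 * (qx * qz - qy * w), 2 * (qy * qz + qx * w), 1 - 2 * (qx**2 + qy**2)],
    ])

print(qdcm(jnp.array([1.0, 0.0, 0.0, 0.0])))        # identity rotation
print(qdcm(jnp.array([0.7071, 0.0, 0.0, 0.7071])))  # ~90 degree rotation about z
# ---------------------------------------------------------------------------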

    @visitor(SSMP)
    def _visit_ssmp(self, node: SSMP):
        """Lower skew-symmetric matrix for quaternion dynamics (4x4).

        Creates a 4x4 skew-symmetric matrix from angular velocity vector for
        quaternion kinematic propagation: q_dot = 0.5 * SSMP(omega) @ q

        The SSMP matrix is used in quaternion kinematics to compute quaternion
        derivatives from angular velocity vectors.

        Args:
            node: SSMP expression node

        Returns:
            Function (x, u, node, params) -> 4x4 skew-symmetric matrix

        Note:
            For angular velocity w = [x, y, z], returns:
            [[0, -x, -y, -z],
             [x,  0,  z, -y],
             [y, -z,  0,  x],
             [z,  y, -x,  0]]
        """
        f = self.lower(node.w)

        def ssmp_fn(x, u, node, params):
            w = f(x, u, node, params)
            wx, wy, wz = w[0], w[1], w[2]
            return jnp.array(
                [
                    [0, -wx, -wy, -wz],
                    [wx, 0, wz, -wy],
                    [wy, -wz, 0, wx],
                    [wz, wy, -wx, 0],
                ]
            )

        return ssmp_fn

    @visitor(SSM)
    def _visit_ssm(self, node: SSM):
        """Lower skew-symmetric matrix for cross product (3x3).

        Creates a 3x3 skew-symmetric matrix from a vector such that
        SSM(a) @ b = a x b (cross product).

        The SSM is the matrix representation of the cross product operator,
        allowing cross products to be computed as matrix-vector multiplication.

        Args:
            node: SSM expression node

        Returns:
            Function (x, u, node, params) -> 3x3 skew-symmetric matrix

        Note:
            For vector w = [x, y, z], returns:
            [[ 0, -z,  y],
             [ z,  0, -x],
             [-y,  x,  0]]
        """
        f = self.lower(node.w)

        def ssm_fn(x, u, node, params):
            w = f(x, u, node, params)
            wx, wy, wz = w[0], w[1], w[2]
            return jnp.array([[0, -wz, wy], [wz, 0, -wx], [-wy, wx, 0]])

        return ssm_fn
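
# ---------------------------------------------------------------------------
# [Editor's note] Illustrative sketch, not part of jax.py: the defining
# property SSM(a) @ b == a x b, checked numerically with made-up vectors.
import jax.numpy as jnp

def ssm(w):
    wx, wy, wz = w
    return jnp.array([[0.0, -wz, wy], [wz, 0.0, -wx], [-wy, wx, 0.0]])

a = jnp.array([1.0, 2.0, 3.0])
b = jnp.array([4.0, 5.0, 6.0])
print(ssm(a) @ b)         # [-3.  6. -3.]
print(jnp.cross(a, b))    # [-3.  6. -3.]
# ---------------------------------------------------------------------------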

    @visitor(AdjointDual)
    def _visit_adjoint_dual(self, node: AdjointDual):
        """Lower coadjoint operator ad* for rigid body dynamics.

        Computes the coadjoint action ad*_ξ(μ) which represents Coriolis and
        centrifugal forces in rigid body dynamics. This is the key term in
        Newton-Euler equations.

        For se(3), given twist ξ = [v; ω] and momentum μ = [f; τ]:

            ad*_ξ(μ) = [ ω × f + v × τ ]
                       [ ω × τ         ]

        This appears in the equations of motion as:
            M @ ξ_dot = F_ext - ad*_ξ(M @ ξ)

        Args:
            node: AdjointDual expression node

        Returns:
            Function (x, u, node, params) -> 6D coadjoint result

        Note:
            Convention: twist = [v; ω] (linear velocity, angular velocity)
            momentum = [f; τ] (force, torque)
        """
        f_twist = self.lower(node.twist)
        f_momentum = self.lower(node.momentum)

        def adjoint_dual_fn(x, u, node, params):
            twist = f_twist(x, u, node, params)
            momentum = f_momentum(x, u, node, params)

            # Extract components: twist = [v; ω], momentum = [f; τ]
            v = twist[:3]  # Linear velocity
            omega = twist[3:]  # Angular velocity
            f = momentum[:3]  # Force (or linear momentum)
            tau = momentum[3:]  # Torque (or angular momentum)

            # Coadjoint action: ad*_ξ(μ) = [ω × f + v × τ; ω × τ]
            linear_part = jnp.cross(omega, f) + jnp.cross(v, tau)
            angular_part = jnp.cross(omega, tau)

            return jnp.concatenate([linear_part, angular_part])

        return adjoint_dual_fn

    @visitor(Adjoint)
    def _visit_adjoint(self, node: Adjoint):
        """Lower adjoint operator ad (Lie bracket) for twist-on-twist action.

        Computes the adjoint action ad_ξ₁(ξ₂) which represents the Lie bracket
        [ξ₁, ξ₂] of two twists. Used for velocity propagation and acceleration
        computation in kinematic chains.

        For se(3), given twists ξ₁ = [v₁; ω₁] and ξ₂ = [v₂; ω₂]:

            ad_ξ₁(ξ₂) = [ ω₁ × v₂ - ω₂ × v₁ ]
                        [ ω₁ × ω₂          ]

        Args:
            node: Adjoint expression node

        Returns:
            Function (x, u, node, params) -> 6D Lie bracket result

        Note:
            The Lie bracket is antisymmetric: [ξ₁, ξ₂] = -[ξ₂, ξ₁]
        """
        f_twist1 = self.lower(node.twist1)
        f_twist2 = self.lower(node.twist2)

        def adjoint_fn(x, u, node, params):
            twist1 = f_twist1(x, u, node, params)
            twist2 = f_twist2(x, u, node, params)

            # Extract components: twist = [v; ω]
            v1 = twist1[:3]
            omega1 = twist1[3:]
            v2 = twist2[:3]
            omega2 = twist2[3:]

            # Lie bracket: [ξ₁, ξ₂] = [ω₁ × v₂ - ω₂ × v₁; ω₁ × ω₂]
            linear_part = jnp.cross(omega1, v2) - jnp.cross(omega2, v1)
            angular_part = jnp.cross(omega1, omega2)

            return jnp.concatenate([linear_part, angular_part])

        return adjoint_fn

    @visitor(SE3Adjoint)
    def _visit_se3_adjoint(self, node: SE3Adjoint):
        """Lower SE3 Adjoint (big Ad) for transforming twists between frames.

        Computes the 6×6 adjoint matrix Ad_T that transforms twists:
            ξ_b = Ad_{T_ab} @ ξ_a

        For SE(3) with rotation R and translation p:
            Ad_T = [ R      0 ]
                   [ [p]×R  R ]

        Args:
            node: SE3Adjoint expression node

        Returns:
            Function (x, u, node, params) -> 6×6 adjoint matrix
        """
        f_transform = self.lower(node.transform)

        def se3_adjoint_fn(x, u, node, params):
            T = f_transform(x, u, node, params)

            # Extract rotation and translation from 4×4 homogeneous matrix
            R = T[:3, :3]
            p = T[:3, 3]

            # Build skew-symmetric matrix [p]×
            p_skew = jnp.array([[0, -p[2], p[1]], [p[2], 0, -p[0]], [-p[1], p[0], 0]])

            # Build 6×6 adjoint matrix
            # Ad_T = [ R      0 ]
            #        [ [p]×R  R ]
            top_row = jnp.hstack([R, jnp.zeros((3, 3))])
            bottom_row = jnp.hstack([p_skew @ R, R])

            return jnp.vstack([top_row, bottom_row])

        return se3_adjoint_fn

    @visitor(SE3AdjointDual)
    def _visit_se3_adjoint_dual(self, node: SE3AdjointDual):
        """Lower SE3 coadjoint (big Ad*) for transforming wrenches between frames.

        Computes the 6×6 coadjoint matrix Ad*_T that transforms wrenches:
            F_a = Ad*_{T_ab} @ F_b

        For SE(3) with rotation R and translation p:
            Ad*_T = [ R  [p]×R ]
                    [ 0  R     ]

        This is the transpose-inverse of Ad_T.

        Args:
            node: SE3AdjointDual expression node

        Returns:
            Function (x, u, node, params) -> 6×6 coadjoint matrix
        """
        f_transform = self.lower(node.transform)

        def se3_adjoint_dual_fn(x, u, node, params):
            T = f_transform(x, u, node, params)

            # Extract rotation and translation from 4×4 homogeneous matrix
            R = T[:3, :3]
            p = T[:3, 3]

            # Build skew-symmetric matrix [p]×
            p_skew = jnp.array([[0, -p[2], p[1]], [p[2], 0, -p[0]], [-p[1], p[0], 0]])

            # Build 6×6 coadjoint matrix
            # Ad*_T = [ R  [p]×R ]
            #         [ 0  R     ]
            top_row = jnp.hstack([R, p_skew @ R])
            bottom_row = jnp.hstack([jnp.zeros((3, 3)), R])

            return jnp.vstack([top_row, bottom_row])

        return se3_adjoint_dual_fn

    @visitor(SO3Exp)
    def _visit_so3_exp(self, node: SO3Exp):
        """Lower SO3 exponential map using jaxlie.

        Maps a 3D rotation vector (axis-angle) to a 3×3 rotation matrix
        using jaxlie's numerically robust implementation.

        Args:
            node: SO3Exp expression node

        Returns:
            Function (x, u, node, params) -> 3×3 rotation matrix
        """
        import jaxlie

        f_omega = self.lower(node.omega)

        def so3_exp_fn(x, u, node, params):
            omega = f_omega(x, u, node, params)
            return jaxlie.SO3.exp(omega).as_matrix()

        return so3_exp_fn
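
# ---------------------------------------------------------------------------
# [Editor's note] Illustrative sketch, not part of jax.py: what the jaxlie
# calls above compute, using made-up rotation vectors (requires jaxlie).
import jax.numpy as jnp
import jaxlie

print(jaxlie.SO3.exp(jnp.zeros(3)).as_matrix())            # identity matrix
Rz = jaxlie.SO3.exp(jnp.array([0.0, 0.0, jnp.pi / 2]))     # 90 degrees about z
print(Rz.as_matrix() @ jnp.array([1.0, 0.0, 0.0]))         # ~[0, 1, 0]
print(jaxlie.SO3.from_matrix(Rz.as_matrix()).log())        # ~[0, 0, pi/2] (round trip)
# ---------------------------------------------------------------------------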

    @visitor(SO3Log)
    def _visit_so3_log(self, node: SO3Log):
        """Lower SO3 logarithm map using jaxlie.

        Maps a 3×3 rotation matrix to a 3D rotation vector (axis-angle)
        using jaxlie's numerically robust implementation.

        Args:
            node: SO3Log expression node

        Returns:
            Function (x, u, node, params) -> 3D rotation vector
        """
        import jaxlie

        f_rotation = self.lower(node.rotation)

        def so3_log_fn(x, u, node, params):
            rotation = f_rotation(x, u, node, params)
            return jaxlie.SO3.from_matrix(rotation).log()

        return so3_log_fn

    @visitor(SE3Exp)
    def _visit_se3_exp(self, node: SE3Exp):
        """Lower SE3 exponential map using jaxlie.

        Maps a 6D twist vector [v; ω] to a 4×4 homogeneous transformation
        matrix using jaxlie's numerically robust implementation.

        The twist convention [v; ω] (linear first, angular second) matches
        jaxlie's SE3 tangent parameterization, so no reordering is needed.

        Args:
            node: SE3Exp expression node

        Returns:
            Function (x, u, node, params) -> 4×4 transformation matrix
        """
        import jaxlie

        f_twist = self.lower(node.twist)

        def se3_exp_fn(x, u, node, params):
            twist = f_twist(x, u, node, params)
            return jaxlie.SE3.exp(twist).as_matrix()

        return se3_exp_fn

    @visitor(SE3Log)
    def _visit_se3_log(self, node: SE3Log):
        """Lower SE3 logarithm map using jaxlie.

        Maps a 4×4 homogeneous transformation matrix to a 6D twist vector
        [v; ω] using jaxlie's numerically robust implementation.

        Args:
            node: SE3Log expression node

        Returns:
            Function (x, u, node, params) -> 6D twist vector
        """
        import jaxlie

        f_transform = self.lower(node.transform)

        def se3_log_fn(x, u, node, params):
            transform = f_transform(x, u, node, params)
            return jaxlie.SE3.from_matrix(transform).log()

        return se3_log_fn

    @visitor(Diag)
    def _visit_diag(self, node: Diag):
        """Lower diagonal matrix construction to JAX function."""
        f = self.lower(node.operand)
        return lambda x, u, node, params: jnp.diag(f(x, u, node, params))

    @visitor(Or)
    def _visit_or(self, node: Or):
        """Lower STL disjunction (Or) to JAX using STLJax library.

        Converts a symbolic Or constraint to an STLJax Or formula for handling
        disjunctive task specifications. Each operand becomes an STLJax predicate.

        Args:
            node: Or expression node with multiple operands

        Returns:
            Function (x, u, node, params) -> STL robustness value

        Note:
            Uses STLJax library for signal temporal logic evaluation. The returned
            function computes the robustness metric for the disjunction, which is
            positive when at least one operand is satisfied.

        Example:
            Used for task specifications like "reach goal A OR goal B"::

                goal_A = ox.Norm(x - target_A) <= 1.0
                goal_B = ox.Norm(x - target_B) <= 1.0
                task = ox.Or(goal_A, goal_B)

        See Also:
            - stljax.formula.Or: Underlying STLJax implementation
            - STL robustness: Quantitative measure of constraint satisfaction
        """
        from stljax.formula import Or as STLOr
        from stljax.formula import Predicate

        # Lower each operand to get their functions
        operand_fns = [self.lower(operand) for operand in node.operands]

        # Return a function that evaluates the STLJax Or
        def or_fn(x, u, node, params):
            # Create STLJax predicates for each operand with current params
            predicates = []
            for i, operand_fn in enumerate(operand_fns):
                # Create a predicate function that captures the current params
                def make_pred_fn(fn):
                    return lambda x: fn(x, None, None, params)

                pred_fn = make_pred_fn(operand_fn)
                predicates.append(Predicate(f"pred_{i}", pred_fn))

            # Create and evaluate STLJax Or formula
            stl_or = STLOr(*predicates)
            return stl_or(x)

        return or_fn

    @visitor(NodeReference)
    def _visit_node_reference(self, node: NodeReference):
        """Lower NodeReference - extract value at a specific trajectory node.

        NodeReference extracts a state/control value at a specific node from the
        full trajectory arrays. The node index is baked into the lowered function.

        Args:
            node: NodeReference expression with base and node_idx (integer)

        Returns:
            Function (x, u, node_param, params) that extracts from trajectory
            - x, u: Full trajectories (N, n_x) and (N, n_u)
            - node_param: Unused (kept for signature compatibility)
            - params: Problem parameters

        Example:
            position.at(5) lowers to a function that extracts x[5, position_slice]
            position.at(k-1) where k=7 lowers to extract x[6, position_slice]
        """
        from openscvx.symbolic.expr.control import Control
        from openscvx.symbolic.expr.state import State

        # Node index is baked into the expression at construction time
        fixed_idx = node.node_idx

        if isinstance(node.base, State):
            sl = node.base._slice
            if sl is None:
                raise ValueError(f"State {node.base.name!r} has no slice assigned")

            def state_node_fn(x, u, node_param, params):
                return x[fixed_idx, sl]

            return state_node_fn

        elif isinstance(node.base, Control):
            sl = node.base._slice
            if sl is None:
                raise ValueError(f"Control {node.base.name!r} has no slice assigned")

            def control_node_fn(x, u, node_param, params):
                return u[fixed_idx, sl]

            return control_node_fn

        else:
            # Compound expression (e.g., position[0].at(5))
            base_fn = self.lower(node.base)

            def compound_node_fn(x, u, node_param, params):
                # Extract single-node slices and evaluate base expression
                x_single = x[fixed_idx] if len(x.shape) > 1 else x
                u_single = u[fixed_idx] if len(u.shape) > 1 else u
                return base_fn(x_single, u_single, fixed_idx, params)

            return compound_node_fn

    @visitor(CrossNodeConstraint)
    def _visit_cross_node_constraint(self, node: CrossNodeConstraint):
        """Lower CrossNodeConstraint to trajectory-level function.

        CrossNodeConstraint wraps constraints that reference multiple trajectory
        nodes via NodeReference (e.g., rate limits like x.at(k) - x.at(k-1) <= r).

        Unlike regular nodal constraints which have signature (x, u, node, params)
        and are vmapped across nodes, cross-node constraints operate on full
        trajectory arrays and return a scalar residual.

        Args:
            node: CrossNodeConstraint expression wrapping the inner constraint

        Returns:
            Function with signature (X, U, params) -> scalar residual
            - X: Full state trajectory, shape (N, n_x)
            - U: Full control trajectory, shape (N, n_u)
            - params: Dictionary of problem parameters
            - Returns: Scalar constraint residual (g <= 0 convention)

        Note:
            The inner constraint is lowered first (producing a function with the
            standard (x, u, node, params) signature), then wrapped to provide the
            trajectory-level (X, U, params) signature. The `node` parameter is
            unused since NodeReference nodes have fixed indices baked in.

        Example:
            For constraint: position.at(5) - position.at(4) <= max_step

            The lowered function evaluates:
                X[5, pos_slice] - X[4, pos_slice] - max_step

            And returns a scalar residual.
        """
        # Lower the inner constraint expression
        inner_fn = self.lower(node.constraint)

        # Wrap to provide trajectory-level signature
        # The `node` parameter is unused for cross-node constraints since
        # NodeReference nodes have fixed indices baked in at construction time
        def trajectory_constraint(X, U, params):
            return inner_fn(X, U, 0, params)

        return trajectory_constraint