openscvx 0.3.2.dev170__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of openscvx might be problematic. Click here for more details.
- openscvx/__init__.py +123 -0
- openscvx/_version.py +34 -0
- openscvx/algorithms/__init__.py +92 -0
- openscvx/algorithms/autotuning.py +24 -0
- openscvx/algorithms/base.py +351 -0
- openscvx/algorithms/optimization_results.py +215 -0
- openscvx/algorithms/penalized_trust_region.py +384 -0
- openscvx/config.py +437 -0
- openscvx/discretization/__init__.py +47 -0
- openscvx/discretization/discretization.py +236 -0
- openscvx/expert/__init__.py +23 -0
- openscvx/expert/byof.py +326 -0
- openscvx/expert/lowering.py +419 -0
- openscvx/expert/validation.py +357 -0
- openscvx/integrators/__init__.py +48 -0
- openscvx/integrators/runge_kutta.py +281 -0
- openscvx/lowered/__init__.py +30 -0
- openscvx/lowered/cvxpy_constraints.py +23 -0
- openscvx/lowered/cvxpy_variables.py +124 -0
- openscvx/lowered/dynamics.py +34 -0
- openscvx/lowered/jax_constraints.py +133 -0
- openscvx/lowered/parameters.py +54 -0
- openscvx/lowered/problem.py +70 -0
- openscvx/lowered/unified.py +718 -0
- openscvx/plotting/__init__.py +63 -0
- openscvx/plotting/plotting.py +756 -0
- openscvx/plotting/scp_iteration.py +299 -0
- openscvx/plotting/viser/__init__.py +126 -0
- openscvx/plotting/viser/animated.py +605 -0
- openscvx/plotting/viser/plotly_integration.py +333 -0
- openscvx/plotting/viser/primitives.py +355 -0
- openscvx/plotting/viser/scp.py +459 -0
- openscvx/plotting/viser/server.py +112 -0
- openscvx/problem.py +734 -0
- openscvx/propagation/__init__.py +60 -0
- openscvx/propagation/post_processing.py +104 -0
- openscvx/propagation/propagation.py +248 -0
- openscvx/solvers/__init__.py +51 -0
- openscvx/solvers/cvxpy.py +226 -0
- openscvx/symbolic/__init__.py +9 -0
- openscvx/symbolic/augmentation.py +630 -0
- openscvx/symbolic/builder.py +492 -0
- openscvx/symbolic/constraint_set.py +92 -0
- openscvx/symbolic/expr/__init__.py +222 -0
- openscvx/symbolic/expr/arithmetic.py +517 -0
- openscvx/symbolic/expr/array.py +632 -0
- openscvx/symbolic/expr/constraint.py +796 -0
- openscvx/symbolic/expr/control.py +135 -0
- openscvx/symbolic/expr/expr.py +720 -0
- openscvx/symbolic/expr/lie/__init__.py +87 -0
- openscvx/symbolic/expr/lie/adjoint.py +357 -0
- openscvx/symbolic/expr/lie/se3.py +172 -0
- openscvx/symbolic/expr/lie/so3.py +138 -0
- openscvx/symbolic/expr/linalg.py +279 -0
- openscvx/symbolic/expr/math.py +699 -0
- openscvx/symbolic/expr/spatial.py +209 -0
- openscvx/symbolic/expr/state.py +607 -0
- openscvx/symbolic/expr/stl.py +136 -0
- openscvx/symbolic/expr/variable.py +321 -0
- openscvx/symbolic/hashing.py +112 -0
- openscvx/symbolic/lower.py +760 -0
- openscvx/symbolic/lowerers/__init__.py +106 -0
- openscvx/symbolic/lowerers/cvxpy.py +1302 -0
- openscvx/symbolic/lowerers/jax.py +1382 -0
- openscvx/symbolic/preprocessing.py +757 -0
- openscvx/symbolic/problem.py +110 -0
- openscvx/symbolic/time.py +116 -0
- openscvx/symbolic/unified.py +420 -0
- openscvx/utils/__init__.py +20 -0
- openscvx/utils/cache.py +131 -0
- openscvx/utils/caching.py +210 -0
- openscvx/utils/printing.py +301 -0
- openscvx/utils/profiling.py +37 -0
- openscvx/utils/utils.py +100 -0
- openscvx-0.3.2.dev170.dist-info/METADATA +350 -0
- openscvx-0.3.2.dev170.dist-info/RECORD +79 -0
- openscvx-0.3.2.dev170.dist-info/WHEEL +5 -0
- openscvx-0.3.2.dev170.dist-info/licenses/LICENSE +201 -0
- openscvx-0.3.2.dev170.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,720 @@
|
|
|
1
|
+
"""Core symbolic expression system for trajectory optimization.
|
|
2
|
+
|
|
3
|
+
This module provides the foundation for openscvx's symbolic expression framework,
|
|
4
|
+
implementing an Abstract Syntax Tree (AST) representation for mathematical expressions
|
|
5
|
+
used in optimization problems. The expression system enables:
|
|
6
|
+
|
|
7
|
+
- Declarative problem specification: Write optimization problems using familiar
|
|
8
|
+
mathematical notation with operator overloading (+, -, *, /, @, **, etc.)
|
|
9
|
+
- Automatic differentiation: Expressions are automatically differentiated during
|
|
10
|
+
compilation to solver-specific formats
|
|
11
|
+
- Shape checking: Static validation of tensor dimensions before optimization
|
|
12
|
+
- Canonicalization: Algebraic simplification for more efficient compilation
|
|
13
|
+
- Multiple backends: Expressions can be compiled to CVXPy, JAX, or custom solvers
|
|
14
|
+
|
|
15
|
+
Architecture:
|
|
16
|
+
The expression system is built around an AST where each node is an `Expr` subclass:
|
|
17
|
+
|
|
18
|
+
- Leaf nodes: `Parameter`, `Variable`, `State`, `Control` - symbolic values
|
|
19
|
+
- Arithmetic operations: `Add`, `Sub`, `Mul`, `Div`, `MatMul`, `Power`, `Neg`
|
|
20
|
+
- Array operations: `Index`, `Concat`, `Stack`, `Hstack`, `Vstack`
|
|
21
|
+
- Linear algebra: `Transpose`, `Diag`, `Sum`, `Norm`
|
|
22
|
+
- Constraints: `Equality`, `Inequality`
|
|
23
|
+
- Functions: `Sin`, `Cos`, `Exp`, `Log`, `Sqrt`, etc.
|
|
24
|
+
|
|
25
|
+
Each expression node implements:
|
|
26
|
+
|
|
27
|
+
- `children()`: Returns child expressions in the AST
|
|
28
|
+
- `canonicalize()`: Returns a simplified/normalized version
|
|
29
|
+
- `check_shape()`: Validates and returns the output shape
|
|
30
|
+
|
|
31
|
+
Example:
|
|
32
|
+
Creating symbolic variables and expressions::
|
|
33
|
+
|
|
34
|
+
import numpy as np
import openscvx as ox
|
|
35
|
+
|
|
36
|
+
# Define symbolic variables
|
|
37
|
+
x = ox.State("x", shape=(3,))
|
|
38
|
+
A = ox.Parameter("A", shape=(3, 3), value=np.eye(3))
|
|
39
|
+
|
|
40
|
+
# Build expressions using natural syntax
|
|
41
|
+
expr = A @ x + 5
|
|
42
|
+
constraint = ox.Norm(x) <= 1.0
|
|
43
|
+
|
|
44
|
+
# Expressions form an AST
|
|
45
|
+
print(expr.pretty()) # Visualize the tree structure
|
|
46
|
+
|
|
47
|
+
Shape checking with automatic validation::
|
|
48
|
+
|
|
49
|
+
x = ox.State("x", shape=(3,))
|
|
50
|
+
y = ox.State("y", shape=(4,))
|
|
51
|
+
|
|
52
|
+
# This will raise ValueError during shape checking
|
|
53
|
+
try:
|
|
54
|
+
expr = x + y # Shapes (3,) and (4,) not broadcastable
|
|
55
|
+
expr.check_shape()
|
|
56
|
+
except ValueError as e:
|
|
57
|
+
print(f"Shape error: {e}")
|
|
58
|
+
|
|
59
|
+
Algebraic canonicalization::
|
|
60
|
+
|
|
61
|
+
x = ox.State("x", shape=(3,))
|
|
62
|
+
expr = x + 0 + (1 * x)
|
|
63
|
+
canonical = expr.canonicalize() # Simplifies to: x + x
|
|
64
|
+
"""
|
|
65
|
+
|
|
66
|
+
import hashlib
|
|
67
|
+
import struct
|
|
68
|
+
from typing import Callable, Tuple, Union
|
|
69
|
+
|
|
70
|
+
import numpy as np
|
|
71
|
+
|
|
72
|
+
|
|
73
|
+
class Expr:
    """Base class for symbolic expressions in optimization problems.

    Expr is the foundation of the symbolic expression system in openscvx. It represents
    nodes in an abstract syntax tree (AST) for mathematical expressions. Expressions
    support:

    - Arithmetic operations: +, -, *, /, @, **
    - Comparison operations: ==, <=, >=
    - Indexing and slicing: []
    - Transposition: .T property
    - Shape checking and validation
    - Canonicalization (algebraic simplification)

    All Expr subclasses implement a tree structure where each node can have child
    expressions accessed via the children() method.

    Attributes:
        __array_priority__: Priority for operations with numpy arrays (set to 1000)

    Note:
        When used in operations with numpy arrays, Expr objects take precedence,
        allowing symbolic expressions to wrap numeric values automatically.

    Note:
        Expr overrides ``__eq__`` (to build an ``Equality`` constraint) without
        defining ``__hash__``, so per Python's data model plain Expr instances
        are not hashable unless a subclass defines ``__hash__``. Use
        :meth:`structural_hash` for a content-based digest.
    """

    # Give Expr objects higher priority than numpy arrays in operations
    __array_priority__ = 1000

    def __le__(self, other):
        from .constraint import Inequality

        return Inequality(self, to_expr(other))

    def __ge__(self, other):
        from .constraint import Inequality

        return Inequality(to_expr(other), self)

    def __eq__(self, other):
        from .constraint import Equality

        return Equality(self, to_expr(other))

    def __add__(self, other):
        from .arithmetic import Add

        return Add(self, to_expr(other))

    def __radd__(self, other):
        from .arithmetic import Add

        return Add(to_expr(other), self)

    def __sub__(self, other):
        from .arithmetic import Sub

        return Sub(self, to_expr(other))

    def __rsub__(self, other):
        # e.g. 5 - a ⇒ Sub(Constant(5), a)
        from .arithmetic import Sub

        return Sub(to_expr(other), self)

    def __truediv__(self, other):
        from .arithmetic import Div

        return Div(self, to_expr(other))

    def __rtruediv__(self, other):
        # e.g. 10 / a
        from .arithmetic import Div

        return Div(to_expr(other), self)

    def __mul__(self, other):
        from .arithmetic import Mul

        return Mul(self, to_expr(other))

    def __rmul__(self, other):
        from .arithmetic import Mul

        return Mul(to_expr(other), self)

    def __matmul__(self, other):
        from .arithmetic import MatMul

        return MatMul(self, to_expr(other))

    def __rmatmul__(self, other):
        from .arithmetic import MatMul

        return MatMul(to_expr(other), self)

    # NOTE: Python defines no reflected ("__r*__") variants of the rich
    # comparison operators. The reflection of ``<=`` is ``__ge__``, and ``==``
    # reflects to ``__eq__``. The three methods below are therefore never
    # invoked by operator dispatch; they only run when called explicitly and
    # are kept for backward compatibility.
    def __rle__(self, other):
        # other <= self => Inequality(other, self)
        from .constraint import Inequality

        return Inequality(to_expr(other), self)

    def __rge__(self, other):
        # other >= self => Inequality(self, other)
        from .constraint import Inequality

        return Inequality(self, to_expr(other))

    def __req__(self, other):
        # other == self => Equality(other, self)
        from .constraint import Equality

        return Equality(to_expr(other), self)

    def __neg__(self):
        from .arithmetic import Neg

        return Neg(self)

    def __pow__(self, other):
        from .arithmetic import Power

        return Power(self, to_expr(other))

    def __rpow__(self, other):
        from .arithmetic import Power

        return Power(to_expr(other), self)

    def __getitem__(self, idx):
        from .array import Index

        return Index(self, idx)

    @property
    def T(self):
        """Transpose property for matrix expressions.

        Returns:
            Transpose: A Transpose expression wrapping this expression

        Example:
            Create a transpose:

                A = ox.State("A", shape=(3, 4))
                A_T = A.T  # Creates Transpose(A), result shape (4, 3)
        """
        from .linalg import Transpose

        return Transpose(self)

    def at(self, k: int) -> "NodeReference":
        """Reference this expression at a specific trajectory node.

        This method enables inter-node constraints where you can reference
        the value of an expression at different time steps. Common patterns
        include rate limits and multi-step dependencies.

        Args:
            k: Absolute node index (integer) in the trajectory.
                Can be positive (0, 1, 2, ...) or negative (-1 for last node).

        Returns:
            NodeReference: An expression representing this expression at node k

        Example:
            Rate limit constraint (applied across trajectory using a loop):

                position = State("pos", shape=(3,))

                # Create rate limit for each node
                constraints = [
                    (ox.linalg.Norm(position.at(k) - position.at(k-1)) <= 0.1).at([k])
                    for k in range(1, N)
                ]

            Multi-step dependency:

                state = State("x", shape=(1,))

                # Fibonacci-like recurrence
                constraints = [
                    (state.at(k) == state.at(k-1) + state.at(k-2)).at([k])
                    for k in range(2, N)
                ]

        Performance Note:
            Cross-node constraints use dense Jacobian storage which can be memory-intensive
            for large N (>100 nodes). See LoweredCrossNodeConstraint documentation for
            details on memory usage and future sparse Jacobian support.
        """
        return NodeReference(self, k)

    def children(self):
        """Return the child expressions of this node.

        Returns:
            list: List of child Expr objects. Empty list for leaf nodes.
        """
        return []

    def canonicalize(self) -> "Expr":
        """
        Return a canonical (simplified) form of this expression.

        Canonicalization performs algebraic simplifications such as:
        - Constant folding (e.g., 2 + 3 → 5)
        - Identity elimination (e.g., x + 0 → x, x * 1 → x)
        - Flattening nested operations (e.g., Add(Add(a, b), c) → Add(a, b, c))
        - Algebraic rewrites (e.g., constraints to standard form)

        Returns:
            Expr: A canonical version of this expression

        Raises:
            NotImplementedError: If canonicalization is not implemented for this node type
        """
        raise NotImplementedError(f"canonicalize() not implemented for {self.__class__.__name__}")

    def check_shape(self) -> Tuple[int, ...]:
        """
        Compute and validate the shape of this expression.

        This method:
        1. Recursively checks shapes of all child expressions
        2. Validates that operations are shape-compatible (e.g., broadcasting rules)
        3. Returns the output shape of this expression

        For example:
        - A Parameter with shape (3, 4) returns (3, 4)
        - MatMul of (3, 4) @ (4, 5) returns (3, 5)
        - Sum of any shape returns () (scalar)
        - Add broadcasts shapes like NumPy

        Returns:
            tuple: The shape of this expression as a tuple of integers.
                Empty tuple () represents a scalar.

        Raises:
            NotImplementedError: If shape checking is not implemented for this node type
            ValueError: If the expression has invalid shapes (e.g., incompatible dimensions)
        """
        raise NotImplementedError(f"check_shape() not implemented for {self.__class__.__name__}")

    def pretty(self, indent=0):
        """Generate a pretty-printed string representation of the expression tree.

        Creates an indented, hierarchical view of the expression tree structure,
        useful for debugging and visualization. Each level of depth adds one
        unit of ``pad`` before the node's class name.

        Args:
            indent: Current indentation level (default: 0)

        Returns:
            str: Multi-line string representation of the expression tree

        Example:
            Pretty print an expression:

                expr = (x + y) * z
                print(expr.pretty())
                # Mul
                #  Add
                #   State
                #   State
                #  State
        """
        pad = " " * indent
        lines = [f"{pad}{self.__class__.__name__}"]
        for child in self.children():
            lines.append(child.pretty(indent + 1))
        return "\n".join(lines)

    def _hash_into(self, hasher: "hashlib._Hash") -> None:
        """Contribute this expression's structural identity to a hash.

        This method is used to compute a structural hash of the expression tree
        that is name-invariant (same structure = same hash regardless of variable names).

        The default implementation hashes the class name, then the number of
        children, then recursively hashes each child. Recording the child count
        keeps the pre-order byte stream unambiguous for variable-arity nodes
        that rely on this default: without it, trees such as
        ``F(G(a, b), c, d)`` and ``F(G(a, b, c), d)`` would feed identical
        bytes to the hasher and collide.

        Subclasses with additional attributes (like Norm.ord, Index.index)
        should override this to include those attributes.

        Args:
            hasher: A hashlib hash object to update
        """
        # Hash the class name to distinguish different node types
        hasher.update(self.__class__.__name__.encode())
        children = list(self.children())
        # Arity prefix (unsigned 32-bit, big-endian) so that the flattened
        # pre-order stream uniquely determines the tree shape.
        hasher.update(struct.pack(">I", len(children)))
        # Recursively hash all children
        for child in children:
            child._hash_into(hasher)

    def structural_hash(self) -> bytes:
        """Compute a structural hash of this expression.

        Returns a hash that depends only on the mathematical structure of the
        expression, not on variable names. Two expressions that are structurally
        equivalent (same operations, same variable positions) will have the same hash.

        Returns:
            bytes: SHA-256 digest of the expression structure
        """
        hasher = hashlib.sha256()
        self._hash_into(hasher)
        return hasher.digest()
|
|
378
|
+
|
|
379
|
+
|
|
380
|
+
class Leaf(Expr):
    """
    Terminal node of the symbolic expression tree.

    A leaf is a named symbol with a fixed shape and no sub-expressions.
    Parameters, Variables, States and Controls are all built on top of it.

    Attributes:
        name (str): Identifier of the leaf node
        _shape (tuple): Shape the leaf node was constructed with
    """

    def __init__(self, name: str, shape: tuple = ()):
        """Create a leaf node.

        Args:
            name (str): Identifier of the leaf node
            shape (tuple): Shape of the leaf node (defaults to ``()``, a scalar)
        """
        super().__init__()
        self.name = name
        self._shape = shape

    @property
    def shape(self):
        """The shape this leaf was constructed with.

        Returns:
            tuple: Shape of the leaf node
        """
        return self._shape

    def children(self):
        """Report that a leaf has no sub-expressions.

        Returns:
            list: Always an empty list
        """
        return []

    def canonicalize(self) -> "Expr":
        """Return the leaf unchanged; there is nothing to simplify.

        Returns:
            Expr: ``self``
        """
        return self

    def check_shape(self) -> Tuple[int, ...]:
        """Report the stored shape without further validation.

        Returns:
            tuple: The shape of the leaf node
        """
        return self._shape

    def _hash_into(self, hasher: "hashlib._Hash") -> None:
        """Feed the class name and the shape of this leaf into ``hasher``.

        The leaf's ``name`` is deliberately left out, which keeps the hash
        name-invariant. Subclasses such as Variable and Parameter override
        this to add their own canonical identifiers.

        Args:
            hasher: A hashlib hash object to update
        """
        for token in (self.__class__.__name__, str(self._shape)):
            hasher.update(token.encode())

    def __repr__(self):
        """Describe the leaf as ``ClassName('name', shape=...)``.

        Returns:
            str: A string describing the leaf node
        """
        kind = self.__class__.__name__
        return f"{kind}('{self.name}', shape={self.shape})"
|
|
456
|
+
|
|
457
|
+
|
|
458
|
+
class Parameter(Leaf):
    """Named numeric value that can be updated at runtime without recompilation.

    A Parameter is a leaf that carries an initial value, stored as a float
    ndarray. Its structural hash depends on its shape only, not on the value,
    so updating the value through the problem's parameter dictionary does not
    change the structure used for compilation.

    Example:
        obs_center = ox.Parameter("obs_center", shape=(3,), value=np.array([1.0, 0.0, 0.0]))
        # Later: problem.parameters["obs_center"] = new_value
    """

    def __init__(self, name: str, shape: tuple = (), value=None):
        """Create a Parameter node.

        Args:
            name (str): Identifier of the parameter
            shape (tuple): Declared shape of the parameter (defaults to a scalar)
            value: Initial value. It is mandatory despite the ``None`` default,
                and is converted with ``np.asarray(..., dtype=float)``.

        Raises:
            ValueError: If ``value`` is omitted or ``None``.
        """
        super().__init__(name, shape)
        if value is not None:
            self.value = np.asarray(value, dtype=float)
            return
        raise ValueError(f"Parameter '{name}' requires an initial value")

    def _hash_into(self, hasher: "hashlib._Hash") -> None:
        """Feed a value-invariant identity of this Parameter into ``hasher``.

        Only the fixed tag ``b"Parameter"`` and the declared shape are hashed.
        The value is not, so a compiled solver can be reused across parameter
        sweeps.

        Args:
            hasher: A hashlib hash object to update
        """
        for chunk in (b"Parameter", str(self._shape).encode()):
            hasher.update(chunk)
|
|
495
|
+
|
|
496
|
+
|
|
497
|
+
def to_expr(x: Union[Expr, float, int, np.ndarray]) -> Expr:
    """Coerce ``x`` into an expression node.

    Operators call this on their right-hand operand so that plain numbers and
    arrays can be mixed freely with symbolic expressions.

    Args:
        x: An Expr, a numeric scalar, or anything ``np.array`` accepts

    Returns:
        ``x`` itself when it is already an Expr. Otherwise, a Constant built
        from ``np.array(x)``.

    Note:
        Because ``np.array`` runs first, Constant receives an ndarray and keeps
        its dtype. A Python int therefore becomes an integer-dtype Constant,
        whereas ``Constant(5)`` called directly yields a float array.
    """
    if isinstance(x, Expr):
        return x
    wrapped = np.array(x)
    return Constant(wrapped)
|
|
511
|
+
|
|
512
|
+
|
|
513
|
+
def traverse(expr: "Expr", visit: Callable[["Expr"], None]):
    """Depth-first (pre-order) traversal of an expression tree.

    Applies ``visit`` to each node before any of its descendants. Children are
    visited left to right, in the order returned by ``children()``, which is
    exactly the order a recursive pre-order walk produces.

    The walk uses an explicit stack instead of recursion. Very deep trees,
    such as the left-nested ``Add`` chain produced by repeatedly doing
    ``total = total + term`` in a loop, therefore cannot exhaust Python's
    recursion limit.

    Args:
        expr: Root expression node to start traversal from
        visit: Callback function applied to each node during traversal
    """
    stack = [expr]
    while stack:
        node = stack.pop()
        visit(node)
        # Children are fetched after visiting the node, as in the recursive
        # form. They are pushed in reverse so the first child is popped first.
        stack.extend(reversed(list(node.children())))
|
|
526
|
+
|
|
527
|
+
|
|
528
|
+
class Constant(Expr):
    """Constant value expression.

    Represents a constant numeric value in the expression tree. The value is
    squeezed (size-1 axes removed) on construction, so that, e.g.,
    ``Constant(5.0)`` and ``Constant([5.0])`` hold identical arrays.

    Attributes:
        value: The numpy array representing the constant value (squeezed)

    Example:
        Define constants:

            c1 = Constant(5.0)  # Scalar constant
            c2 = Constant([1, 2, 3])  # Vector constant
            c3 = to_expr(10)  # Also creates a Constant
    """

    def __init__(self, value: np.ndarray):
        """Initialize a constant expression.

        Args:
            value: Numeric value or numpy array to wrap as a constant.
                Non-ndarray inputs are converted with ``dtype=float``. ndarray
                inputs keep their dtype. The result is always squeezed.
        """
        # Normalize immediately upon construction to ensure consistency
        # This ensures Constant(5.0) and Constant([5.0]) create identical objects
        if not isinstance(value, np.ndarray):
            value = np.array(value, dtype=float)
        self.value = np.squeeze(value)

    def canonicalize(self) -> "Expr":
        """Constants are already in canonical form.

        Returns:
            Expr: Returns self since constants are already canonical
        """
        return self

    def check_shape(self) -> Tuple[int, ...]:
        """Return the shape of this constant's value.

        Returns:
            tuple: The shape of the constant's numpy array value

        Raises:
            ValueError: If ``value`` is not squeezed. ``__init__`` always
                squeezes, so this can only happen if ``value`` was reassigned
                after construction.
        """
        # Verify the invariant: constants should already be squeezed during construction
        original_shape = self.value.shape
        squeezed_shape = np.squeeze(self.value).shape
        if original_shape != squeezed_shape:
            raise ValueError(
                f"Constant not properly normalized: has shape {original_shape} "
                f"but should have shape {squeezed_shape}. "
                "Constants should be squeezed during construction."
            )
        return self.value.shape

    def _hash_into(self, hasher: "hashlib._Hash") -> None:
        """Hash constant by its value.

        Constants are hashed by their shape and their raw value bytes, so that
        expressions with the same constant values produce the same hash.

        Args:
            hasher: A hashlib hash object to update
        """
        hasher.update(b"Constant")
        hasher.update(str(self.value.shape).encode())
        hasher.update(self.value.tobytes())

    def __repr__(self):
        # Show clean representation - always show as Python values, not numpy arrays
        if self.value.ndim == 0:
            # Scalar: show as plain number
            return f"Const({self.value.item()!r})"
        else:
            # Array: show as Python list for readability
            return f"Const({self.value.tolist()!r})"
|
|
604
|
+
|
|
605
|
+
|
|
606
|
+
class NodeReference(Expr):
    """Reference to a variable at a specific trajectory node.

    NodeReference enables inter-node constraints by allowing you to reference
    the value of a state or control variable at a specific discrete time point
    (node) in the trajectory. This is essential for expressing temporal relationships
    such as:

    - Rate limits and smoothness constraints
    - Multi-step dependencies and recurrence relations
    - Constraints coupling specific nodes

    Attributes:
        base: The expression (typically a Leaf like State or Control) being referenced
        node_idx: Trajectory node index as a Python ``int`` (can be negative
            for end-indexing)

    Example:
        Rate limit across trajectory:

            position = State("pos", shape=(3,))

            # Create rate limit constraints for all nodes
            constraints = [
                (ox.linalg.Norm(position.at(k) - position.at(k-1)) <= 0.1).at([k])
                for k in range(1, N)
            ]

        Multi-step dependency:

            state = State("x", shape=(1,))

            # Fibonacci-like recurrence at each node
            constraints = [
                (state.at(k) == state.at(k-1) + state.at(k-2)).at([k])
                for k in range(2, N)
            ]

        Coupling specific nodes:

            # Constrain distance between nodes 5 and 10
            coupling = (position.at(10) - position.at(5) <= threshold).at([10])

    Performance Note:
        Cross-node constraints use dense Jacobian storage. For details on memory
        usage and performance implications, see LoweredCrossNodeConstraint documentation.

    Note:
        NodeReference is typically created via the `.at(k)` method on expressions
        rather than constructed directly.
    """

    def __init__(self, base: Expr, node_idx: int):
        """Initialize a NodeReference.

        Args:
            base: Expression to reference at a specific node (typically a Leaf)
            node_idx: Absolute trajectory node index. Python ints and NumPy
                integer scalars (e.g. values from ``np.arange``) are accepted.
                Negative indexing is supported (e.g., -1 for last node).

        Raises:
            TypeError: If node_idx is not an integer
        """
        if not isinstance(node_idx, (int, np.integer)):
            raise TypeError(f"Node index must be an integer, got {type(node_idx).__name__}")

        # Store a plain Python int, so that repr, struct packing in
        # _hash_into, and downstream index arithmetic behave the same whether
        # the caller passed an int or a NumPy integer.
        self.node_idx = int(node_idx)
        self.base = base

    def children(self):
        """Return the base expression as the only child.

        Returns:
            list: Single-element list containing the base expression
        """
        return [self.base]

    def canonicalize(self) -> "Expr":
        """Canonicalize by canonicalizing the base expression.

        Returns:
            NodeReference: A new NodeReference with canonicalized base
        """
        canon_base = self.base.canonicalize()
        return NodeReference(canon_base, self.node_idx)

    def check_shape(self) -> Tuple[int, ...]:
        """Return the shape of the base expression.

        NodeReference doesn't change the shape of the underlying expression,
        it just references it at a specific time point.

        Returns:
            tuple: The shape of the base expression
        """
        return self.base.check_shape()

    def _hash_into(self, hasher: "hashlib._Hash") -> None:
        """Hash NodeReference including its node index.

        Args:
            hasher: A hashlib hash object to update
        """
        hasher.update(b"NodeReference")
        # Hash the node index (signed 32-bit, big-endian)
        hasher.update(struct.pack(">i", self.node_idx))
        # Hash the base expression
        self.base._hash_into(hasher)

    def __repr__(self):
        """String representation of the NodeReference.

        Returns:
            str: String showing the base expression and node index
        """
        return f"{self.base!r}.at({self.node_idx})"
|