openscvx 0.3.2.dev170__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of openscvx might be problematic. Click here for more details.

Files changed (79) hide show
  1. openscvx/__init__.py +123 -0
  2. openscvx/_version.py +34 -0
  3. openscvx/algorithms/__init__.py +92 -0
  4. openscvx/algorithms/autotuning.py +24 -0
  5. openscvx/algorithms/base.py +351 -0
  6. openscvx/algorithms/optimization_results.py +215 -0
  7. openscvx/algorithms/penalized_trust_region.py +384 -0
  8. openscvx/config.py +437 -0
  9. openscvx/discretization/__init__.py +47 -0
  10. openscvx/discretization/discretization.py +236 -0
  11. openscvx/expert/__init__.py +23 -0
  12. openscvx/expert/byof.py +326 -0
  13. openscvx/expert/lowering.py +419 -0
  14. openscvx/expert/validation.py +357 -0
  15. openscvx/integrators/__init__.py +48 -0
  16. openscvx/integrators/runge_kutta.py +281 -0
  17. openscvx/lowered/__init__.py +30 -0
  18. openscvx/lowered/cvxpy_constraints.py +23 -0
  19. openscvx/lowered/cvxpy_variables.py +124 -0
  20. openscvx/lowered/dynamics.py +34 -0
  21. openscvx/lowered/jax_constraints.py +133 -0
  22. openscvx/lowered/parameters.py +54 -0
  23. openscvx/lowered/problem.py +70 -0
  24. openscvx/lowered/unified.py +718 -0
  25. openscvx/plotting/__init__.py +63 -0
  26. openscvx/plotting/plotting.py +756 -0
  27. openscvx/plotting/scp_iteration.py +299 -0
  28. openscvx/plotting/viser/__init__.py +126 -0
  29. openscvx/plotting/viser/animated.py +605 -0
  30. openscvx/plotting/viser/plotly_integration.py +333 -0
  31. openscvx/plotting/viser/primitives.py +355 -0
  32. openscvx/plotting/viser/scp.py +459 -0
  33. openscvx/plotting/viser/server.py +112 -0
  34. openscvx/problem.py +734 -0
  35. openscvx/propagation/__init__.py +60 -0
  36. openscvx/propagation/post_processing.py +104 -0
  37. openscvx/propagation/propagation.py +248 -0
  38. openscvx/solvers/__init__.py +51 -0
  39. openscvx/solvers/cvxpy.py +226 -0
  40. openscvx/symbolic/__init__.py +9 -0
  41. openscvx/symbolic/augmentation.py +630 -0
  42. openscvx/symbolic/builder.py +492 -0
  43. openscvx/symbolic/constraint_set.py +92 -0
  44. openscvx/symbolic/expr/__init__.py +222 -0
  45. openscvx/symbolic/expr/arithmetic.py +517 -0
  46. openscvx/symbolic/expr/array.py +632 -0
  47. openscvx/symbolic/expr/constraint.py +796 -0
  48. openscvx/symbolic/expr/control.py +135 -0
  49. openscvx/symbolic/expr/expr.py +720 -0
  50. openscvx/symbolic/expr/lie/__init__.py +87 -0
  51. openscvx/symbolic/expr/lie/adjoint.py +357 -0
  52. openscvx/symbolic/expr/lie/se3.py +172 -0
  53. openscvx/symbolic/expr/lie/so3.py +138 -0
  54. openscvx/symbolic/expr/linalg.py +279 -0
  55. openscvx/symbolic/expr/math.py +699 -0
  56. openscvx/symbolic/expr/spatial.py +209 -0
  57. openscvx/symbolic/expr/state.py +607 -0
  58. openscvx/symbolic/expr/stl.py +136 -0
  59. openscvx/symbolic/expr/variable.py +321 -0
  60. openscvx/symbolic/hashing.py +112 -0
  61. openscvx/symbolic/lower.py +760 -0
  62. openscvx/symbolic/lowerers/__init__.py +106 -0
  63. openscvx/symbolic/lowerers/cvxpy.py +1302 -0
  64. openscvx/symbolic/lowerers/jax.py +1382 -0
  65. openscvx/symbolic/preprocessing.py +757 -0
  66. openscvx/symbolic/problem.py +110 -0
  67. openscvx/symbolic/time.py +116 -0
  68. openscvx/symbolic/unified.py +420 -0
  69. openscvx/utils/__init__.py +20 -0
  70. openscvx/utils/cache.py +131 -0
  71. openscvx/utils/caching.py +210 -0
  72. openscvx/utils/printing.py +301 -0
  73. openscvx/utils/profiling.py +37 -0
  74. openscvx/utils/utils.py +100 -0
  75. openscvx-0.3.2.dev170.dist-info/METADATA +350 -0
  76. openscvx-0.3.2.dev170.dist-info/RECORD +79 -0
  77. openscvx-0.3.2.dev170.dist-info/WHEEL +5 -0
  78. openscvx-0.3.2.dev170.dist-info/licenses/LICENSE +201 -0
  79. openscvx-0.3.2.dev170.dist-info/top_level.txt +1 -0
@@ -0,0 +1,720 @@
1
+ """Core symbolic expression system for trajectory optimization.
2
+
3
+ This module provides the foundation for openscvx's symbolic expression framework,
4
+ implementing an Abstract Syntax Tree (AST) representation for mathematical expressions
5
+ used in optimization problems. The expression system enables:
6
+
7
+ - Declarative problem specification: Write optimization problems using familiar
8
+ mathematical notation with operator overloading (+, -, *, /, @, **, etc.)
9
+ - Automatic differentiation: Expressions are automatically differentiated during
10
+ compilation to solver-specific formats
11
+ - Shape checking: Static validation of tensor dimensions before optimization
12
+ - Canonicalization: Algebraic simplification for more efficient compilation
13
+ - Multiple backends: Expressions can be compiled to CVXPy, JAX, or custom solvers
14
+
15
+ Architecture:
16
+ The expression system is built around an AST where each node is an `Expr` subclass:
17
+
18
+ - Leaf nodes: `Parameter`, `Variable`, `State`, `Control` - symbolic values
19
+ - Arithmetic operations: `Add`, `Sub`, `Mul`, `Div`, `MatMul`, `Power`, `Neg`
20
+ - Array operations: `Index`, `Concat`, `Stack`, `Hstack`, `Vstack`
21
+ - Linear algebra: `Transpose`, `Diag`, `Sum`, `Norm`
22
+ - Constraints: `Equality`, `Inequality`
23
+ - Functions: `Sin`, `Cos`, `Exp`, `Log`, `Sqrt`, etc.
24
+
25
+ Each expression node implements:
26
+
27
+ - `children()`: Returns child expressions in the AST
28
+ - `canonicalize()`: Returns a simplified/normalized version
29
+ - `check_shape()`: Validates and returns the output shape
30
+
31
+ Example:
32
+ Creating symbolic variables and expressions::
33
+
34
+ import openscvx as ox
35
+
36
+ # Define symbolic variables
37
+ x = ox.State("x", shape=(3,))
38
+ A = ox.Parameter("A", shape=(3, 3), value=np.eye(3))
39
+
40
+ # Build expressions using natural syntax
41
+ expr = A @ x + 5
42
+ constraint = ox.Norm(x) <= 1.0
43
+
44
+ # Expressions form an AST
45
+ print(expr.pretty()) # Visualize the tree structure
46
+
47
+ Shape checking with automatic validation::
48
+
49
+ x = ox.State("x", shape=(3,))
50
+ y = ox.State("y", shape=(4,))
51
+
52
+ # This will raise ValueError during shape checking
53
+ try:
54
+ expr = x + y # Shapes (3,) and (4,) not broadcastable
55
+ expr.check_shape()
56
+ except ValueError as e:
57
+ print(f"Shape error: {e}")
58
+
59
+ Algebraic canonicalization::
60
+
61
+ x = ox.State("x", shape=(3,))
62
+ expr = x + 0 + (1 * x)
63
+ canonical = expr.canonicalize() # Simplifies to: x + x
64
+ """
65
+
66
+ import hashlib
67
+ import struct
68
+ from typing import Callable, Tuple, Union
69
+
70
+ import numpy as np
71
+
72
+
73
class Expr:
    """Base class for symbolic expressions in optimization problems.

    Expr is the foundation of the symbolic expression system in openscvx. It represents
    nodes in an abstract syntax tree (AST) for mathematical expressions. Expressions
    support:

    - Arithmetic operations: +, -, *, /, @, **
    - Comparison operations: ==, <=, >=
    - Indexing and slicing: []
    - Transposition: .T property
    - Shape checking and validation
    - Canonicalization (algebraic simplification)

    All Expr subclasses implement a tree structure where each node can have child
    expressions accessed via the children() method.

    Attributes:
        __array_priority__: Priority for operations with numpy arrays (set to 1000)

    Note:
        When used in operations with numpy arrays, Expr objects take precedence,
        allowing symbolic expressions to wrap numeric values automatically.

        The comparison operators build constraint nodes instead of returning
        booleans. Because ``__eq__`` is overridden without ``__hash__``, Python
        sets ``__hash__`` to ``None`` for this class, so instances are unhashable
        unless a subclass defines ``__hash__`` itself.
    """

    # Give Expr objects higher priority than numpy arrays in operations
    __array_priority__ = 1000

    # Node classes from sibling modules are imported inside each method rather
    # than at module level.
    # NOTE(review): presumably this avoids circular imports - confirm.

    def __le__(self, other):
        """Build the constraint ``self <= other`` as ``Inequality(self, other)``."""
        from .constraint import Inequality

        return Inequality(self, to_expr(other))

    def __ge__(self, other):
        """Build ``self >= other``, stored as ``Inequality(other, self)``."""
        from .constraint import Inequality

        return Inequality(to_expr(other), self)

    def __eq__(self, other):
        """Build the constraint ``self == other`` (an Equality node, not a bool)."""
        from .constraint import Equality

        return Equality(self, to_expr(other))

    def __add__(self, other):
        """Build ``self + other``."""
        from .arithmetic import Add

        return Add(self, to_expr(other))

    def __radd__(self, other):
        """Build ``other + self`` for a non-Expr left operand."""
        from .arithmetic import Add

        return Add(to_expr(other), self)

    def __sub__(self, other):
        """Build ``self - other``."""
        from .arithmetic import Sub

        return Sub(self, to_expr(other))

    def __rsub__(self, other):
        """Build ``other - self``, e.g. ``5 - a`` becomes ``Sub(Constant(5), a)``."""
        from .arithmetic import Sub

        return Sub(to_expr(other), self)

    def __truediv__(self, other):
        """Build ``self / other``."""
        from .arithmetic import Div

        return Div(self, to_expr(other))

    def __rtruediv__(self, other):
        """Build ``other / self``, e.g. ``10 / a``."""
        from .arithmetic import Div

        return Div(to_expr(other), self)

    def __mul__(self, other):
        """Build ``self * other``."""
        from .arithmetic import Mul

        return Mul(self, to_expr(other))

    def __rmul__(self, other):
        """Build ``other * self`` for a non-Expr left operand."""
        from .arithmetic import Mul

        return Mul(to_expr(other), self)

    def __matmul__(self, other):
        """Build ``self @ other``."""
        from .arithmetic import MatMul

        return MatMul(self, to_expr(other))

    def __rmatmul__(self, other):
        """Build ``other @ self`` for a non-Expr left operand."""
        from .arithmetic import MatMul

        return MatMul(to_expr(other), self)

    def __rle__(self, other):
        """Build ``other <= self`` as ``Inequality(other, self)``.

        Python's data model has no reflected comparison hook with this name;
        the interpreter resolves ``other <= self`` through ``self.__ge__``.
        This method therefore only runs when it is called explicitly.
        """
        from .constraint import Inequality

        return Inequality(to_expr(other), self)

    def __rge__(self, other):
        """Build ``other >= self`` as ``Inequality(self, other)``.

        Not part of Python's operator dispatch (``other >= self`` is resolved
        through ``self.__le__``); only runs when called explicitly.
        """
        from .constraint import Inequality

        return Inequality(self, to_expr(other))

    def __req__(self, other):
        """Build ``other == self`` as ``Equality(other, self)``.

        Not part of Python's operator dispatch (``==`` is its own reflection);
        only runs when called explicitly.
        """
        from .constraint import Equality

        return Equality(to_expr(other), self)

    def __neg__(self):
        """Build ``-self``."""
        from .arithmetic import Neg

        return Neg(self)

    def __pow__(self, other):
        """Build ``self ** other``."""
        from .arithmetic import Power

        return Power(self, to_expr(other))

    def __rpow__(self, other):
        """Build ``other ** self`` for a non-Expr base."""
        from .arithmetic import Power

        return Power(to_expr(other), self)

    def __getitem__(self, idx):
        """Build an ``Index`` node for ``self[idx]``; ``idx`` is passed through as-is."""
        from .array import Index

        return Index(self, idx)

    @property
    def T(self):
        """Transpose property for matrix expressions.

        Returns:
            Transpose: A Transpose expression wrapping this expression

        Example:
            Create a transpose:

                A = ox.State("A", shape=(3, 4))
                A_T = A.T  # Creates Transpose(A), result shape (4, 3)
        """
        from .linalg import Transpose

        return Transpose(self)

    def at(self, k: int) -> "NodeReference":
        """Reference this expression at a specific trajectory node.

        This method enables inter-node constraints where you can reference
        the value of an expression at different time steps. Common patterns
        include rate limits and multi-step dependencies.

        Args:
            k: Absolute node index (integer) in the trajectory.
                Can be positive (0, 1, 2, ...) or negative (-1 for last node).

        Returns:
            NodeReference: An expression representing this expression at node k

        Raises:
            TypeError: If ``k`` is not an ``int`` (raised by ``NodeReference``).

        Example:
            Rate limit constraint (applied across trajectory using a loop):

                position = State("pos", shape=(3,))

                # Create rate limit for each node
                constraints = [
                    (ox.linalg.Norm(position.at(k) - position.at(k-1)) <= 0.1).at([k])
                    for k in range(1, N)
                ]

            Multi-step dependency:

                state = State("x", shape=(1,))

                # Fibonacci-like recurrence
                constraints = [
                    (state.at(k) == state.at(k-1) + state.at(k-2)).at([k])
                    for k in range(2, N)
                ]

        Performance Note:
            Cross-node constraints use dense Jacobian storage which can be memory-intensive
            for large N (>100 nodes). See LoweredCrossNodeConstraint documentation for
            details on memory usage and future sparse Jacobian support.
        """
        return NodeReference(self, k)

    def children(self):
        """Return the child expressions of this node.

        Returns:
            list: List of child Expr objects. Empty list for leaf nodes.
        """
        return []

    def canonicalize(self) -> "Expr":
        """
        Return a canonical (simplified) form of this expression.

        Canonicalization performs algebraic simplifications such as:
        - Constant folding (e.g., 2 + 3 → 5)
        - Identity elimination (e.g., x + 0 → x, x * 1 → x)
        - Flattening nested operations (e.g., Add(Add(a, b), c) → Add(a, b, c))
        - Algebraic rewrites (e.g., constraints to standard form)

        Returns:
            Expr: A canonical version of this expression

        Raises:
            NotImplementedError: If canonicalization is not implemented for this node type
        """
        raise NotImplementedError(f"canonicalize() not implemented for {self.__class__.__name__}")

    def check_shape(self) -> Tuple[int, ...]:
        """
        Compute and validate the shape of this expression.

        This method:
        1. Recursively checks shapes of all child expressions
        2. Validates that operations are shape-compatible (e.g., broadcasting rules)
        3. Returns the output shape of this expression

        For example:
        - A Parameter with shape (3, 4) returns (3, 4)
        - MatMul of (3, 4) @ (4, 5) returns (3, 5)
        - Sum of any shape returns () (scalar)
        - Add broadcasts shapes like NumPy

        Returns:
            tuple: The shape of this expression as a tuple of integers.
                Empty tuple () represents a scalar.

        Raises:
            NotImplementedError: If shape checking is not implemented for this node type
            ValueError: If the expression has invalid shapes (e.g., incompatible dimensions)
        """
        raise NotImplementedError(f"check_shape() not implemented for {self.__class__.__name__}")

    def pretty(self, indent=0):
        """Generate a pretty-printed string representation of the expression tree.

        Creates an indented, hierarchical view of the expression tree structure,
        useful for debugging and visualization. Each line holds one node's class
        name, prefixed by one pad unit per nesting level.

        Args:
            indent: Current indentation level (default: 0)

        Returns:
            str: Multi-line string representation of the expression tree

        Example:
            Pretty print an expression:

                expr = (x + y) * z
                print(expr.pretty())
                # Mul
                #  Add
                #   State
                #   State
                #  State
        """
        pad = " " * indent
        lines = [f"{pad}{self.__class__.__name__}"]
        for child in self.children():
            lines.append(child.pretty(indent + 1))
        return "\n".join(lines)

    def _hash_into(self, hasher: "hashlib._Hash") -> None:
        """Contribute this expression's structural identity to a hash.

        This method is used to compute a structural hash of the expression tree
        that is name-invariant (same structure = same hash regardless of variable names).

        The default implementation hashes the class name and recursively hashes all
        children. Subclasses with additional attributes (like Norm.ord, Index.index)
        should override this to include those attributes.

        Args:
            hasher: A hashlib hash object to update
        """
        # Hash the class name to distinguish different node types
        hasher.update(self.__class__.__name__.encode())
        # Recursively hash all children
        for child in self.children():
            child._hash_into(hasher)

    def structural_hash(self) -> bytes:
        """Compute a structural hash of this expression.

        Returns a hash that depends only on the mathematical structure of the
        expression, not on variable names. Two expressions that are structurally
        equivalent (same operations, same variable positions) will have the same hash.

        Returns:
            bytes: SHA-256 digest of the expression structure
        """
        hasher = hashlib.sha256()
        self._hash_into(hasher)
        return hasher.digest()
378
+
379
+
380
class Leaf(Expr):
    """Terminal node of the symbolic expression tree.

    A leaf is a named symbolic quantity with a fixed shape and no child
    expressions. Parameters, Variables, States, and Controls are leaves.

    Attributes:
        name (str): Name identifier for the leaf node
        _shape (tuple): Shape of the leaf node
    """

    def __init__(self, name: str, shape: tuple = ()):
        """Store the leaf's name and shape.

        Args:
            name (str): Name identifier for the leaf node
            shape (tuple): Shape of the leaf node (default: scalar)
        """
        super().__init__()
        self.name = name
        self._shape = shape

    @property
    def shape(self):
        """Shape of the leaf node, as given at construction.

        Returns:
            tuple: Shape of the leaf node
        """
        return self._shape

    def children(self):
        """Return an empty list: leaves are terminal.

        Returns:
            list: Always a new empty list
        """
        return []

    def canonicalize(self) -> "Expr":
        """Return this node unchanged; a leaf is already canonical.

        Returns:
            Expr: ``self``
        """
        return self

    def check_shape(self) -> Tuple[int, ...]:
        """Report the stored shape without further validation.

        Returns:
            tuple: The shape of the leaf node
        """
        return self._shape

    def _hash_into(self, hasher: "hashlib._Hash") -> None:
        """Feed the class name and shape into ``hasher``.

        The name is deliberately left out, so two leaves of the same class and
        shape contribute identical bytes. Subclasses may override this to add
        their own identifying data.

        Args:
            hasher: A hashlib hash object to update
        """
        # Same two updates, in the same order: class name first, then shape.
        for piece in (type(self).__name__, str(self._shape)):
            hasher.update(piece.encode())

    def __repr__(self):
        """Describe the leaf as ``ClassName('name', shape=...)``.

        Returns:
            str: A string describing the leaf node
        """
        cls_name = type(self).__name__
        return f"{cls_name}('{self.name}', shape={self.shape})"
456
+
457
+
458
class Parameter(Leaf):
    """Named symbolic value whose number can change without recompilation.

    A Parameter carries an initial value and can later be updated through the
    problem's parameter dictionary, which makes parameter sweeps cheap because
    the optimization problem does not need to be recompiled.

    Attributes:
        value (np.ndarray): Current value, stored as a float array

    Example:
        obs_center = ox.Parameter("obs_center", shape=(3,), value=np.array([1.0, 0.0, 0.0]))
        # Later: problem.parameters["obs_center"] = new_value
    """

    def __init__(self, name: str, shape: tuple = (), value=None):
        """Create a Parameter with a mandatory initial value.

        Args:
            name (str): Name identifier for the parameter
            shape (tuple): Shape of the parameter (default: scalar)
            value: Initial value for the parameter (required); converted to a
                float numpy array

        Raises:
            ValueError: If ``value`` is None
        """
        super().__init__(name, shape)
        if value is None:
            raise ValueError(f"Parameter '{name}' requires an initial value")
        # np.asarray avoids a copy when the input is already a float array.
        self.value = np.asarray(value, dtype=float)

    def _hash_into(self, hasher: "hashlib._Hash") -> None:
        """Feed a fixed tag and the shape into ``hasher``; the value is ignored.

        Leaving the value out keeps the structural hash stable across parameter
        sweeps, so a compiled solver can be reused when only the numbers change.

        Args:
            hasher: A hashlib hash object to update
        """
        shape_bytes = str(self._shape).encode()
        hasher.update(b"Parameter")
        hasher.update(shape_bytes)
495
+
496
+
497
def to_expr(x: Union[Expr, float, int, np.ndarray]) -> Expr:
    """Coerce a value into an Expr.

    Expr instances pass through untouched; anything else is converted with
    ``np.array`` and wrapped in a Constant. The operator overloads use this so
    both operands of a node are always Expr objects.

    Args:
        x: Value to convert - can be an Expr, numeric scalar, or numpy array

    Returns:
        The input if it's already an Expr, otherwise a Constant wrapping the value
    """
    if isinstance(x, Expr):
        return x
    # NOTE(review): np.array(x) keeps the input's dtype (e.g. integers stay
    # integer), whereas Constant's constructor coerces non-array input to
    # float - confirm this asymmetry is intended.
    wrapped = np.array(x)
    return Constant(wrapped)
511
+
512
+
513
def traverse(expr: "Expr", visit: Callable[["Expr"], None]):
    """Depth-first (pre-order) traversal of an expression tree.

    Applies ``visit`` to the current node and then to each of its children,
    left to right, descending fully into one child before moving to the next.

    The walk uses an explicit stack instead of recursion, so expression trees
    deeper than the interpreter's recursion limit (for example a long chain of
    additions) can be traversed without a RecursionError.

    Args:
        expr: Root expression node to start traversal from
        visit: Callback function applied to each node during traversal
    """
    pending = [expr]
    while pending:
        node = pending.pop()
        visit(node)
        # Push children right-to-left so the leftmost child is popped, and
        # therefore visited, first - matching recursive pre-order.
        pending.extend(reversed(list(node.children())))
526
+
527
+
528
class Constant(Expr):
    """Constant value expression.

    Represents a constant numeric value in the expression tree. Constants are
    automatically normalized (squeezed) upon construction to ensure consistency.

    Attributes:
        value: The numpy array representing the constant value (squeezed)

    Example:
        Define constants:

            c1 = Constant(5.0)        # Scalar constant
            c2 = Constant([1, 2, 3])  # Vector constant
            c3 = to_expr(10)          # Also creates a Constant
    """

    def __init__(self, value: np.ndarray):
        """Initialize a constant expression.

        Args:
            value: Numeric value or numpy array to wrap as a constant.
                Non-array input is converted to a float numpy array; an
                existing numpy array keeps its dtype. The result is squeezed.
        """
        # Normalize immediately upon construction to ensure consistency
        # This ensures Constant(5.0) and Constant([5.0]) create identical objects
        if not isinstance(value, np.ndarray):
            value = np.array(value, dtype=float)
        self.value = np.squeeze(value)

    def canonicalize(self) -> "Expr":
        """Constants are already in canonical form.

        Returns:
            Expr: Returns self since constants are already canonical
        """
        return self

    def check_shape(self) -> Tuple[int, ...]:
        """Return the shape of this constant's value.

        Returns:
            tuple: The shape of the constant's numpy array value

        Raises:
            ValueError: If ``value`` is no longer squeezed (for example because
                the attribute was reassigned after construction)
        """
        # Verify the invariant: constants should already be squeezed during construction
        original_shape = self.value.shape
        squeezed_shape = np.squeeze(self.value).shape
        if original_shape != squeezed_shape:
            # Every fragment with a placeholder needs its own f prefix;
            # otherwise "{squeezed_shape}" is emitted literally.
            raise ValueError(
                f"Constant not properly normalized: has shape {original_shape} "
                f"but should have shape {squeezed_shape}. "
                "Constants should be squeezed during construction."
            )
        return self.value.shape

    def _hash_into(self, hasher: "hashlib._Hash") -> None:
        """Hash constant by its value.

        Constants are hashed by their shape and the raw bytes of their value,
        so expressions with the same constant values produce the same hash.

        Args:
            hasher: A hashlib hash object to update
        """
        # NOTE(review): the dtype is not hashed, and the raw bytes of equal
        # numbers differ between dtypes (e.g. int64 5 vs float64 5.0) -
        # confirm that is acceptable for cache keys.
        hasher.update(b"Constant")
        hasher.update(str(self.value.shape).encode())
        hasher.update(self.value.tobytes())

    def __repr__(self):
        """Show the constant as ``Const(<python value>)`` rather than a numpy repr.

        Returns:
            str: ``Const(...)`` wrapping a plain number or a nested list
        """
        # Show clean representation - always show as Python values, not numpy arrays
        if self.value.ndim == 0:
            # Scalar: show as plain number
            return f"Const({self.value.item()!r})"
        else:
            # Array: show as Python list for readability
            return f"Const({self.value.tolist()!r})"
604
+
605
+
606
class NodeReference(Expr):
    """An expression pinned to one specific trajectory node.

    NodeReference makes inter-node constraints possible: it refers to the value
    of a state or control variable at a particular discrete time point (node)
    of the trajectory. Typical uses are:

    - Rate limits and smoothness constraints
    - Multi-step dependencies and recurrence relations
    - Constraints coupling specific nodes

    Attributes:
        base: The expression (typically a Leaf like State or Control) being referenced
        node_idx: Trajectory node index (integer, can be negative for end-indexing)

    Example:
        Rate limit across trajectory:

            position = State("pos", shape=(3,))

            # Create rate limit constraints for all nodes
            constraints = [
                (ox.linalg.Norm(position.at(k) - position.at(k-1)) <= 0.1).at([k])
                for k in range(1, N)
            ]

        Multi-step dependency:

            state = State("x", shape=(1,))

            # Fibonacci-like recurrence at each node
            constraints = [
                (state.at(k) == state.at(k-1) + state.at(k-2)).at([k])
                for k in range(2, N)
            ]

        Coupling specific nodes:

            # Constrain distance between nodes 5 and 10
            coupling = (position.at(10) - position.at(5) <= threshold).at([10])

    Performance Note:
        Cross-node constraints use dense Jacobian storage. For details on memory
        usage and performance implications, see LoweredCrossNodeConstraint documentation.

    Note:
        NodeReference is typically created via the `.at(k)` method on expressions
        rather than constructed directly.
    """

    def __init__(self, base: Expr, node_idx: int):
        """Bind ``base`` to the trajectory node ``node_idx``.

        Args:
            base: Expression to reference at a specific node (typically a Leaf)
            node_idx: Absolute trajectory node index (integer)
                Supports negative indexing (e.g., -1 for last node)

        Raises:
            TypeError: If node_idx is not an integer
        """
        if not isinstance(node_idx, int):
            raise TypeError(f"Node index must be an integer, got {type(node_idx).__name__}")

        self.base = base
        self.node_idx = node_idx

    def children(self):
        """Expose the referenced expression as the sole child.

        Returns:
            list: Single-element list containing the base expression
        """
        return [self.base]

    def canonicalize(self) -> "Expr":
        """Canonicalize the base and re-wrap it at the same node index.

        Returns:
            NodeReference: A new NodeReference with canonicalized base
        """
        return NodeReference(self.base.canonicalize(), self.node_idx)

    def check_shape(self) -> Tuple[int, ...]:
        """Delegate shape checking to the base expression.

        Referencing an expression at a node does not alter its shape.

        Returns:
            tuple: The shape of the base expression
        """
        base_shape = self.base.check_shape()
        return base_shape

    def _hash_into(self, hasher: "hashlib._Hash") -> None:
        """Feed a fixed tag, the node index, and the base's hash data into ``hasher``.

        Args:
            hasher: A hashlib hash object to update
        """
        hasher.update(b"NodeReference")
        # The index is packed as a big-endian signed 32-bit integer.
        packed_idx = struct.pack(">i", self.node_idx)
        hasher.update(packed_idx)
        # The base contributes its own structure last.
        self.base._hash_into(hasher)

    def __repr__(self):
        """Render as ``<base repr>.at(<node index>)``.

        Returns:
            str: String showing the base expression and node index
        """
        return "{!r}.at({})".format(self.base, self.node_idx)