openscvx 0.3.2.dev170__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of openscvx might be problematic. Click here for more details.

Files changed (79) hide show
  1. openscvx/__init__.py +123 -0
  2. openscvx/_version.py +34 -0
  3. openscvx/algorithms/__init__.py +92 -0
  4. openscvx/algorithms/autotuning.py +24 -0
  5. openscvx/algorithms/base.py +351 -0
  6. openscvx/algorithms/optimization_results.py +215 -0
  7. openscvx/algorithms/penalized_trust_region.py +384 -0
  8. openscvx/config.py +437 -0
  9. openscvx/discretization/__init__.py +47 -0
  10. openscvx/discretization/discretization.py +236 -0
  11. openscvx/expert/__init__.py +23 -0
  12. openscvx/expert/byof.py +326 -0
  13. openscvx/expert/lowering.py +419 -0
  14. openscvx/expert/validation.py +357 -0
  15. openscvx/integrators/__init__.py +48 -0
  16. openscvx/integrators/runge_kutta.py +281 -0
  17. openscvx/lowered/__init__.py +30 -0
  18. openscvx/lowered/cvxpy_constraints.py +23 -0
  19. openscvx/lowered/cvxpy_variables.py +124 -0
  20. openscvx/lowered/dynamics.py +34 -0
  21. openscvx/lowered/jax_constraints.py +133 -0
  22. openscvx/lowered/parameters.py +54 -0
  23. openscvx/lowered/problem.py +70 -0
  24. openscvx/lowered/unified.py +718 -0
  25. openscvx/plotting/__init__.py +63 -0
  26. openscvx/plotting/plotting.py +756 -0
  27. openscvx/plotting/scp_iteration.py +299 -0
  28. openscvx/plotting/viser/__init__.py +126 -0
  29. openscvx/plotting/viser/animated.py +605 -0
  30. openscvx/plotting/viser/plotly_integration.py +333 -0
  31. openscvx/plotting/viser/primitives.py +355 -0
  32. openscvx/plotting/viser/scp.py +459 -0
  33. openscvx/plotting/viser/server.py +112 -0
  34. openscvx/problem.py +734 -0
  35. openscvx/propagation/__init__.py +60 -0
  36. openscvx/propagation/post_processing.py +104 -0
  37. openscvx/propagation/propagation.py +248 -0
  38. openscvx/solvers/__init__.py +51 -0
  39. openscvx/solvers/cvxpy.py +226 -0
  40. openscvx/symbolic/__init__.py +9 -0
  41. openscvx/symbolic/augmentation.py +630 -0
  42. openscvx/symbolic/builder.py +492 -0
  43. openscvx/symbolic/constraint_set.py +92 -0
  44. openscvx/symbolic/expr/__init__.py +222 -0
  45. openscvx/symbolic/expr/arithmetic.py +517 -0
  46. openscvx/symbolic/expr/array.py +632 -0
  47. openscvx/symbolic/expr/constraint.py +796 -0
  48. openscvx/symbolic/expr/control.py +135 -0
  49. openscvx/symbolic/expr/expr.py +720 -0
  50. openscvx/symbolic/expr/lie/__init__.py +87 -0
  51. openscvx/symbolic/expr/lie/adjoint.py +357 -0
  52. openscvx/symbolic/expr/lie/se3.py +172 -0
  53. openscvx/symbolic/expr/lie/so3.py +138 -0
  54. openscvx/symbolic/expr/linalg.py +279 -0
  55. openscvx/symbolic/expr/math.py +699 -0
  56. openscvx/symbolic/expr/spatial.py +209 -0
  57. openscvx/symbolic/expr/state.py +607 -0
  58. openscvx/symbolic/expr/stl.py +136 -0
  59. openscvx/symbolic/expr/variable.py +321 -0
  60. openscvx/symbolic/hashing.py +112 -0
  61. openscvx/symbolic/lower.py +760 -0
  62. openscvx/symbolic/lowerers/__init__.py +106 -0
  63. openscvx/symbolic/lowerers/cvxpy.py +1302 -0
  64. openscvx/symbolic/lowerers/jax.py +1382 -0
  65. openscvx/symbolic/preprocessing.py +757 -0
  66. openscvx/symbolic/problem.py +110 -0
  67. openscvx/symbolic/time.py +116 -0
  68. openscvx/symbolic/unified.py +420 -0
  69. openscvx/utils/__init__.py +20 -0
  70. openscvx/utils/cache.py +131 -0
  71. openscvx/utils/caching.py +210 -0
  72. openscvx/utils/printing.py +301 -0
  73. openscvx/utils/profiling.py +37 -0
  74. openscvx/utils/utils.py +100 -0
  75. openscvx-0.3.2.dev170.dist-info/METADATA +350 -0
  76. openscvx-0.3.2.dev170.dist-info/RECORD +79 -0
  77. openscvx-0.3.2.dev170.dist-info/WHEEL +5 -0
  78. openscvx-0.3.2.dev170.dist-info/licenses/LICENSE +201 -0
  79. openscvx-0.3.2.dev170.dist-info/top_level.txt +1 -0
@@ -0,0 +1,517 @@
1
+ """Arithmetic operations for symbolic expressions.
2
+
3
+ This module provides fundamental arithmetic operations that form the building blocks
4
+ of symbolic expressions in openscvx. These operations are created automatically through
5
+ operator overloading on Expr objects.
6
+
7
+ Arithmetic Operations:
8
+
9
+ - **Binary operations:** `Add`, `Sub`, `Mul`, `Div`, `MatMul`, `Power` - Standard arithmetic
10
+ - **Unary operations:** `Neg` - Negation (unary minus)
11
+
12
+ All arithmetic operations support:
13
+ - Automatic canonicalization (constant folding, identity elimination, flattening)
14
+ - Broadcasting following NumPy rules (except MatMul which follows linear algebra rules)
15
+ - Shape checking and validation
16
+
17
+ Example:
18
+ Arithmetic operations are created via operator overloading::
19
+
20
+ import openscvx as ox
21
+
22
+ x = ox.State("x", shape=(3,))
23
+ y = ox.State("y", shape=(3,))
24
+
25
+ # Element-wise operations
26
+ z = x + y # Creates Add(x, y)
27
+ w = x * 2 # Creates Mul(x, Constant(2))
28
+ neg_x = -x # Creates Neg(x)
29
+
30
+ # Matrix multiplication
31
+ A = ox.State("A", shape=(3, 3))
32
+ b = A @ x # Creates MatMul(A, x)
33
+ """
34
+
35
+ from typing import Tuple
36
+
37
+ import numpy as np
38
+
39
+ from .expr import Constant, Expr, to_expr
40
+
41
+
42
class Add(Expr):
    """Addition operation for symbolic expressions.

    Represents element-wise addition of two or more expressions. Supports broadcasting
    following NumPy rules. Can be created using the + operator on Expr objects.

    Attributes:
        terms: List of expression operands to add together

    Example:
        Define an Add expression:

            x = ox.State("x", shape=(3,))
            y = ox.State("y", shape=(3,))
            z = x + y + 5  # Creates Add(x, y, Constant(5))
    """

    def __init__(self, *args):
        """Initialize an addition operation.

        Args:
            *args: Two or more expressions (or values accepted by ``to_expr``) to add

        Raises:
            ValueError: If fewer than two operands are provided
        """
        if len(args) < 2:
            raise ValueError("Add requires two or more operands")
        self.terms = [to_expr(a) for a in args]

    def children(self):
        """Return a shallow copy of the operand list."""
        return list(self.terms)

    def canonicalize(self) -> "Expr":
        """Canonicalize addition: flatten, fold constants, and eliminate zeros.

        Nested ``Add`` nodes are spliced into a single operand list, and every
        ``Constant`` operand (including one carried by a nested ``Add``) is summed
        into a single constant. That constant is dropped when it is all zeros and
        at least one non-constant operand remains. When every operand is constant,
        the folded constant is returned unchanged so its value and shape are kept.

        Returns:
            Expr: Canonical form of the addition expression
        """
        terms = []
        const_vals = []

        for t in self.terms:
            c = t.canonicalize()
            # A canonical nested Add may already hold a folded Constant; route each
            # of its operands through the same dispatch so constants keep folding.
            operands = c.terms if isinstance(c, Add) else [c]
            for op in operands:
                if isinstance(op, Constant):
                    const_vals.append(op.value)
                else:
                    terms.append(op)

        if const_vals:
            # Element-wise (broadcasting) sum of all constant operands.
            total = sum(const_vals)
            if not terms:
                # Pure-constant sum: return it directly. Dropping an all-zero array
                # here would collapse e.g. zeros((3,)) to a scalar and change shape.
                return Constant(total)
            # np.asarray treats Python/NumPy scalars and arrays uniformly, so scalar
            # zeros are eliminated just like all-zero arrays.
            # NOTE(review): dropping an all-zero *array* whose shape is larger than
            # the remaining terms' shapes narrows the broadcast result shape.
            if not np.all(np.asarray(total) == 0):
                terms.append(Constant(total))

        if not terms:
            return Constant(np.array(0))
        if len(terms) == 1:
            return terms[0]
        return Add(*terms)

    def check_shape(self) -> Tuple[int, ...]:
        """Check shape compatibility and compute broadcasted result shape like NumPy.

        Returns:
            tuple: The broadcasted shape of all operands

        Raises:
            ValueError: If operand shapes are not broadcastable
        """
        shapes = [child.check_shape() for child in self.children()]
        try:
            return np.broadcast_shapes(*shapes)
        except ValueError as e:
            raise ValueError(f"Add shapes not broadcastable: {shapes}") from e

    def __repr__(self):
        inner = " + ".join(repr(e) for e in self.terms)
        return f"({inner})"
123
+
124
+
125
class Sub(Expr):
    """Subtraction operation for symbolic expressions.

    Represents element-wise subtraction (left - right). Supports broadcasting
    following NumPy rules. Can be created using the - operator on Expr objects.

    Attributes:
        left: Left-hand side expression (minuend)
        right: Right-hand side expression (subtrahend)

    Example:
        Define a Sub expression:

            x = ox.State("x", shape=(3,))
            y = ox.State("y", shape=(3,))
            z = x - y  # Creates Sub(x, y)
    """

    def __init__(self, left, right):
        """Initialize a subtraction operation.

        Args:
            left: Expression (or value accepted by ``to_expr``) to subtract from
            right: Expression (or value accepted by ``to_expr``) to subtract
        """
        # Normalize operands the same way Add/Mul/Power do, so that e.g. Sub(x, 1)
        # holds Expr children whose canonicalize()/check_shape() can be called.
        self.left = to_expr(left)
        self.right = to_expr(right)

    def children(self):
        """Return the operands in order: [left, right]."""
        return [self.left, self.right]

    def canonicalize(self) -> "Expr":
        """Canonicalize subtraction: fold constants if both sides are constants.

        Returns:
            Expr: Canonical form of the subtraction expression
        """
        left = self.left.canonicalize()
        right = self.right.canonicalize()
        if isinstance(left, Constant) and isinstance(right, Constant):
            return Constant(left.value - right.value)
        return Sub(left, right)

    def check_shape(self) -> Tuple[int, ...]:
        """Check shape compatibility and compute broadcasted result shape like NumPy.

        Returns:
            tuple: The broadcasted shape of both operands

        Raises:
            ValueError: If operand shapes are not broadcastable
        """
        shapes = [child.check_shape() for child in self.children()]
        try:
            return np.broadcast_shapes(*shapes)
        except ValueError as e:
            raise ValueError(f"Sub shapes not broadcastable: {shapes}") from e

    def __repr__(self):
        return f"({self.left!r} - {self.right!r})"
185
+
186
+
187
class Mul(Expr):
    """Element-wise multiplication operation for symbolic expressions.

    Represents element-wise (Hadamard) multiplication of two or more expressions.
    Supports broadcasting following NumPy rules. Can be created using the * operator
    on Expr objects. For matrix multiplication, use MatMul or the @ operator.

    Attributes:
        factors: List of expression operands to multiply together

    Example:
        Define a Mul expression:

            x = ox.State("x", shape=(3,))
            y = ox.State("y", shape=(3,))
            z = x * y * 2  # Creates Mul(x, y, Constant(2))
    """

    def __init__(self, *args):
        """Initialize an element-wise multiplication operation.

        Args:
            *args: Two or more expressions (or values accepted by ``to_expr``) to multiply

        Raises:
            ValueError: If fewer than two operands are provided
        """
        if len(args) < 2:
            raise ValueError("Mul requires two or more operands")
        self.factors = [to_expr(a) for a in args]

    def children(self):
        """Return a shallow copy of the operand list."""
        return list(self.factors)

    def canonicalize(self) -> "Expr":
        """Canonicalize multiplication: flatten, fold constants, and eliminate ones.

        Nested ``Mul`` nodes are spliced into a single operand list, and every
        ``Constant`` operand (including one carried by a nested ``Mul``) is
        multiplied into a single constant. That constant is dropped when it is all
        ones and at least one non-constant operand remains. When every operand is
        constant, the folded constant is returned unchanged so its value and shape
        are kept.

        Returns:
            Expr: Canonical form of the multiplication expression
        """
        factors = []
        const_vals = []

        for f in self.factors:
            c = f.canonicalize()
            # A canonical nested Mul may already hold a folded Constant; route each
            # of its operands through the same dispatch so constants keep folding.
            operands = c.factors if isinstance(c, Mul) else [c]
            for op in operands:
                if isinstance(op, Constant):
                    const_vals.append(op.value)
                else:
                    factors.append(op)

        if const_vals:
            # Multiply constants element-wise (broadcasting), not reducing with prod
            prod = const_vals[0]
            for val in const_vals[1:]:
                prod = prod * val

            if not factors:
                # Pure-constant product: return it directly. Dropping an all-ones
                # array here would collapse e.g. ones((3,)) to a scalar 1.
                return Constant(prod)

            # np.asarray treats scalars and arrays uniformly for the identity test.
            # NOTE(review): dropping an all-ones *array* whose shape is larger than
            # the remaining factors' shapes narrows the broadcast result shape.
            if not np.all(np.asarray(prod) == 1):
                factors.append(Constant(prod))

        if not factors:
            return Constant(np.array(1))
        if len(factors) == 1:
            return factors[0]
        return Mul(*factors)

    def check_shape(self) -> Tuple[int, ...]:
        """Check shape compatibility and compute broadcasted result shape like NumPy.

        Returns:
            tuple: The broadcasted shape of all operands

        Raises:
            ValueError: If operand shapes are not broadcastable
        """
        shapes = [child.check_shape() for child in self.children()]
        try:
            return np.broadcast_shapes(*shapes)
        except ValueError as e:
            raise ValueError(f"Mul shapes not broadcastable: {shapes}") from e

    def __repr__(self):
        inner = " * ".join(repr(e) for e in self.factors)
        return f"({inner})"
281
+
282
+
283
class Div(Expr):
    """Element-wise division operation for symbolic expressions.

    Represents element-wise division (left / right). Supports broadcasting
    following NumPy rules. Can be created using the / operator on Expr objects.

    Attributes:
        left: Numerator expression
        right: Denominator expression

    Example:
        Define a Div expression:

            x = ox.State("x", shape=(3,))
            y = ox.State("y", shape=(3,))
            z = x / y  # Creates Div(x, y)
    """

    def __init__(self, left, right):
        """Initialize a division operation.

        Args:
            left: Numerator expression (or value accepted by ``to_expr``)
            right: Denominator expression (or value accepted by ``to_expr``)
        """
        # Normalize operands the same way Add/Mul/Power do, so that e.g. Div(x, 2)
        # holds Expr children whose canonicalize()/check_shape() can be called.
        self.left = to_expr(left)
        self.right = to_expr(right)

    def children(self):
        """Return the operands in order: [left, right]."""
        return [self.left, self.right]

    def canonicalize(self) -> "Expr":
        """Canonicalize division: fold constants if both sides are constants.

        Returns:
            Expr: Canonical form of the division expression
        """
        lhs = self.left.canonicalize()
        rhs = self.right.canonicalize()
        if isinstance(lhs, Constant) and isinstance(rhs, Constant):
            return Constant(lhs.value / rhs.value)
        return Div(lhs, rhs)

    def check_shape(self) -> Tuple[int, ...]:
        """Check shape compatibility and compute broadcasted result shape like NumPy.

        Returns:
            tuple: The broadcasted shape of both operands

        Raises:
            ValueError: If operand shapes are not broadcastable
        """
        shapes = [child.check_shape() for child in self.children()]
        try:
            return np.broadcast_shapes(*shapes)
        except ValueError as e:
            raise ValueError(f"Div shapes not broadcastable: {shapes}") from e

    def __repr__(self):
        return f"({self.left!r} / {self.right!r})"
343
+
344
+
345
class MatMul(Expr):
    """Matrix multiplication operation for symbolic expressions.

    Represents matrix multiplication following NumPy ``matmul`` shape rules.
    Can be created using the @ operator on Expr objects. Handles:

    - Matrix @ Matrix: (..., m, n) @ (..., n, k) -> (..., m, k), where the leading
      batch dimensions of both operands are broadcast together
    - Matrix @ Vector: (..., m, n) @ (n,) -> (..., m)
    - Vector @ Matrix: (m,) @ (..., m, n) -> (..., n)
    - Vector @ Vector: (m,) @ (m,) -> scalar

    Attributes:
        left: Left-hand side expression
        right: Right-hand side expression

    Example:
        Define a MatMul expression:

            A = ox.State("A", shape=(3, 4))
            x = ox.State("x", shape=(4,))
            y = A @ x  # Creates MatMul(A, x), result shape (3,)
    """

    def __init__(self, left, right):
        """Initialize a matrix multiplication operation.

        Args:
            left: Left-hand side expression (or value accepted by ``to_expr``)
            right: Right-hand side expression (or value accepted by ``to_expr``)
        """
        # Normalize operands the same way Add/Mul/Power do, so raw arrays become
        # Expr children whose canonicalize()/check_shape() can be called.
        self.left = to_expr(left)
        self.right = to_expr(right)

    def children(self):
        """Return the operands in order: [left, right]."""
        return [self.left, self.right]

    def canonicalize(self) -> "Expr":
        """Canonicalize both operands; no constant folding is performed.

        Returns:
            Expr: A new MatMul over the canonicalized operands
        """
        left = self.left.canonicalize()
        right = self.right.canonicalize()
        return MatMul(left, right)

    def check_shape(self) -> Tuple[int, ...]:
        """Check matrix multiplication shape compatibility and return result shape.

        Returns:
            tuple: Result shape following NumPy ``matmul`` rules

        Raises:
            ValueError: If either operand is 0-D, the contracted dimensions differ,
                or the batch dimensions cannot be broadcast together
        """
        L, R = self.left.check_shape(), self.right.check_shape()

        if len(L) == 0 or len(R) == 0:
            raise ValueError(f"MatMul requires at least 1D operands: {L} @ {R}")

        if len(L) == 1 and len(R) == 1:
            # Vector @ Vector -> scalar
            if L[0] != R[0]:
                raise ValueError(f"MatMul incompatible: {L} @ {R}")
            return ()

        if len(L) == 1:
            # Vector @ Matrix: (m,) @ (..., m, n) -> (..., n); here len(R) >= 2.
            if L[0] != R[-2]:
                raise ValueError(f"MatMul incompatible: {L} @ {R}")
            # Keep R's batch dims; only the contracted axis disappears.
            return tuple(R[:-2]) + (R[-1],)

        if len(R) == 1:
            # Matrix @ Vector: (..., m, n) @ (n,) -> (..., m)
            if L[-1] != R[0]:
                raise ValueError(f"MatMul incompatible: {L} @ {R}")
            return tuple(L[:-1])

        # Matrix @ Matrix: (..., m, n) @ (..., n, k) -> (..., m, k)
        if L[-1] != R[-2]:
            raise ValueError(f"MatMul incompatible: {L} @ {R}")
        # Leading (batch) dims broadcast like NumPy; empty for plain 2-D operands.
        try:
            batch = np.broadcast_shapes(tuple(L[:-2]), tuple(R[:-2]))
        except ValueError as e:
            raise ValueError(f"MatMul incompatible: {L} @ {R}") from e
        return tuple(batch) + (L[-2], R[-1])

    def __repr__(self):
        # Use "@" so a printed MatMul is not mistaken for element-wise Mul.
        return f"({self.left!r} @ {self.right!r})"
421
+
422
+
423
class Neg(Expr):
    """Negation operation for symbolic expressions.

    Represents element-wise negation (unary minus). Can be created using the
    unary - operator on Expr objects.

    Attributes:
        operand: Expression to negate

    Example:
        Define a Neg expression:

            x = ox.State("x", shape=(3,))
            y = -x  # Creates Neg(x)
    """

    def __init__(self, operand):
        """Initialize a negation operation.

        Args:
            operand: Expression (or value accepted by ``to_expr``) to negate
        """
        # Normalize like Add/Mul/Power do, so Neg(3) holds an Expr child whose
        # canonicalize()/check_shape() can be called.
        self.operand = to_expr(operand)

    def children(self):
        """Return the single operand as a one-element list."""
        return [self.operand]

    def canonicalize(self) -> "Expr":
        """Canonicalize negation: fold constant negations.

        Returns:
            Expr: Canonical form of the negation expression
        """
        o = self.operand.canonicalize()
        if isinstance(o, Constant):
            return Constant(-o.value)
        return Neg(o)

    def check_shape(self) -> Tuple[int, ...]:
        """Negation preserves the shape of its operand."""
        return self.operand.check_shape()

    def __repr__(self):
        return f"(-{self.operand!r})"
467
+
468
+
469
class Power(Expr):
    """Element-wise exponentiation ``base ** exponent`` for symbolic expressions.

    Produced by the ``**`` operator on Expr objects. The base and exponent are
    broadcast against each other following NumPy's rules.

    Attributes:
        base: Expression being raised to a power
        exponent: Expression giving the power

    Example:
        Define a Power expression:

            x = ox.State("x", shape=(3,))
            y = x ** 2  # Creates Power(x, Constant(2))
    """

    def __init__(self, base, exponent):
        """Store both operands, converting each one with ``to_expr``.

        Args:
            base: Base expression (or value accepted by ``to_expr``)
            exponent: Exponent expression (or value accepted by ``to_expr``)
        """
        self.base, self.exponent = to_expr(base), to_expr(exponent)

    def children(self):
        # Fixed operand order: base first, exponent second.
        return [self.base, self.exponent]

    def canonicalize(self) -> "Expr":
        """Build a new Power from the canonical forms of both operands.

        No simplification of the power itself is attempted.

        Returns:
            Expr: Canonical form of the power expression
        """
        return Power(self.base.canonicalize(), self.exponent.canonicalize())

    def check_shape(self) -> Tuple[int, ...]:
        """Return the NumPy broadcast of the base and exponent shapes.

        Returns:
            tuple: The broadcasted shape of base and exponent

        Raises:
            ValueError: If the two shapes cannot be broadcast together
        """
        shapes = [operand.check_shape() for operand in (self.base, self.exponent)]
        try:
            return np.broadcast_shapes(*shapes)
        except ValueError as err:
            raise ValueError(f"Power shapes not broadcastable: {shapes}") from err

    def __repr__(self):
        return "({!r})**({!r})".format(self.base, self.exponent)