openscvx 0.3.2.dev170__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of openscvx might be problematic. Click here for more details.

Files changed (79) hide show
  1. openscvx/__init__.py +123 -0
  2. openscvx/_version.py +34 -0
  3. openscvx/algorithms/__init__.py +92 -0
  4. openscvx/algorithms/autotuning.py +24 -0
  5. openscvx/algorithms/base.py +351 -0
  6. openscvx/algorithms/optimization_results.py +215 -0
  7. openscvx/algorithms/penalized_trust_region.py +384 -0
  8. openscvx/config.py +437 -0
  9. openscvx/discretization/__init__.py +47 -0
  10. openscvx/discretization/discretization.py +236 -0
  11. openscvx/expert/__init__.py +23 -0
  12. openscvx/expert/byof.py +326 -0
  13. openscvx/expert/lowering.py +419 -0
  14. openscvx/expert/validation.py +357 -0
  15. openscvx/integrators/__init__.py +48 -0
  16. openscvx/integrators/runge_kutta.py +281 -0
  17. openscvx/lowered/__init__.py +30 -0
  18. openscvx/lowered/cvxpy_constraints.py +23 -0
  19. openscvx/lowered/cvxpy_variables.py +124 -0
  20. openscvx/lowered/dynamics.py +34 -0
  21. openscvx/lowered/jax_constraints.py +133 -0
  22. openscvx/lowered/parameters.py +54 -0
  23. openscvx/lowered/problem.py +70 -0
  24. openscvx/lowered/unified.py +718 -0
  25. openscvx/plotting/__init__.py +63 -0
  26. openscvx/plotting/plotting.py +756 -0
  27. openscvx/plotting/scp_iteration.py +299 -0
  28. openscvx/plotting/viser/__init__.py +126 -0
  29. openscvx/plotting/viser/animated.py +605 -0
  30. openscvx/plotting/viser/plotly_integration.py +333 -0
  31. openscvx/plotting/viser/primitives.py +355 -0
  32. openscvx/plotting/viser/scp.py +459 -0
  33. openscvx/plotting/viser/server.py +112 -0
  34. openscvx/problem.py +734 -0
  35. openscvx/propagation/__init__.py +60 -0
  36. openscvx/propagation/post_processing.py +104 -0
  37. openscvx/propagation/propagation.py +248 -0
  38. openscvx/solvers/__init__.py +51 -0
  39. openscvx/solvers/cvxpy.py +226 -0
  40. openscvx/symbolic/__init__.py +9 -0
  41. openscvx/symbolic/augmentation.py +630 -0
  42. openscvx/symbolic/builder.py +492 -0
  43. openscvx/symbolic/constraint_set.py +92 -0
  44. openscvx/symbolic/expr/__init__.py +222 -0
  45. openscvx/symbolic/expr/arithmetic.py +517 -0
  46. openscvx/symbolic/expr/array.py +632 -0
  47. openscvx/symbolic/expr/constraint.py +796 -0
  48. openscvx/symbolic/expr/control.py +135 -0
  49. openscvx/symbolic/expr/expr.py +720 -0
  50. openscvx/symbolic/expr/lie/__init__.py +87 -0
  51. openscvx/symbolic/expr/lie/adjoint.py +357 -0
  52. openscvx/symbolic/expr/lie/se3.py +172 -0
  53. openscvx/symbolic/expr/lie/so3.py +138 -0
  54. openscvx/symbolic/expr/linalg.py +279 -0
  55. openscvx/symbolic/expr/math.py +699 -0
  56. openscvx/symbolic/expr/spatial.py +209 -0
  57. openscvx/symbolic/expr/state.py +607 -0
  58. openscvx/symbolic/expr/stl.py +136 -0
  59. openscvx/symbolic/expr/variable.py +321 -0
  60. openscvx/symbolic/hashing.py +112 -0
  61. openscvx/symbolic/lower.py +760 -0
  62. openscvx/symbolic/lowerers/__init__.py +106 -0
  63. openscvx/symbolic/lowerers/cvxpy.py +1302 -0
  64. openscvx/symbolic/lowerers/jax.py +1382 -0
  65. openscvx/symbolic/preprocessing.py +757 -0
  66. openscvx/symbolic/problem.py +110 -0
  67. openscvx/symbolic/time.py +116 -0
  68. openscvx/symbolic/unified.py +420 -0
  69. openscvx/utils/__init__.py +20 -0
  70. openscvx/utils/cache.py +131 -0
  71. openscvx/utils/caching.py +210 -0
  72. openscvx/utils/printing.py +301 -0
  73. openscvx/utils/profiling.py +37 -0
  74. openscvx/utils/utils.py +100 -0
  75. openscvx-0.3.2.dev170.dist-info/METADATA +350 -0
  76. openscvx-0.3.2.dev170.dist-info/RECORD +79 -0
  77. openscvx-0.3.2.dev170.dist-info/WHEEL +5 -0
  78. openscvx-0.3.2.dev170.dist-info/licenses/LICENSE +201 -0
  79. openscvx-0.3.2.dev170.dist-info/top_level.txt +1 -0
@@ -0,0 +1,699 @@
1
+ """Mathematical functions for symbolic expressions.
2
+
3
+ This module provides common mathematical operations used in optimization problems,
4
+ including trigonometric functions, exponential functions, and smooth approximations
5
+ of non-differentiable operations. All functions are element-wise and preserve the
6
+ shape of their inputs.
7
+
8
+ Function Categories:
9
+ - **Trigonometric:** `Sin`, `Cos`, `Tan` - Standard trigonometric functions
10
+ - **Exponential and Roots:** `Exp`, `Log`, `Sqrt`, `Square` - Exponential, logarithm, square
11
+ root, and squaring operations
12
+ - **Absolute Value:** `Abs` - Element-wise absolute value function
13
+ - **Penalty Functions:** `PositivePart`, `Huber`, `SmoothReLU` - `PositivePart` computes max(x, 0) exactly;
14
+ `Huber` and `SmoothReLU` are smooth, differentiable approximations of the non-smooth |x| and max(x, 0)
15
+ - **Element-wise Maximum:** `Max` - Element-wise maximum across two or more (broadcast) operands
16
+ - **Smooth Maximum:** `LogSumExp` - Log-sum-exp function, a smooth approximation to maximum
17
+
18
+ Example:
19
+ Using trigonometric functions in dynamics::
20
+
21
+ import openscvx as ox
22
+
23
+ # Pendulum dynamics: theta_ddot = -g/L * sin(theta)
24
+ theta = ox.State("theta", shape=(1,))
25
+ theta_dot = ox.State("theta_dot", shape=(1,))
26
+ g, L = 9.81, 1.0
27
+
28
+ theta_ddot = -(g / L) * ox.Sin(theta)
29
+
30
+ Smooth penalty functions for constraints::
31
+
32
+ # Soft constraint using smooth ReLU
33
+ x = ox.Variable("x", shape=(3,))
34
+ penalty = ox.SmoothReLU(ox.Norm(x) - 1.0) # Penalize norm > 1
35
+ """
36
+
37
+ import hashlib
38
+ import struct
39
+ from typing import Tuple
40
+
41
+ import numpy as np
42
+
43
+ from .expr import Expr, to_expr
44
+
45
+
46
class Sin(Expr):
    """Element-wise sine function for symbolic expressions.

    Computes the sine of each element in the operand. Preserves the shape
    of the input expression.

    Attributes:
        operand: Expression to apply sine function to

    Example:
        Define a Sin expression:

            theta = Variable("theta", shape=(3,))
            sin_theta = Sin(theta)
    """

    def __init__(self, operand):
        """Initialize a sine operation.

        Args:
            operand: Expression to apply sine function to. Plain values are
                passed through ``to_expr`` so they are accepted like in the
                other element-wise ops of this module.
        """
        # children()/canonicalize()/check_shape() call Expr methods on the
        # operand, so raw numbers must be converted up front.
        self.operand = to_expr(operand)

    def children(self):
        """Return the single child expression."""
        return [self.operand]

    def canonicalize(self) -> "Expr":
        """Return a Sin over the canonicalized operand."""
        operand = self.operand.canonicalize()
        return Sin(operand)

    def check_shape(self) -> Tuple[int, ...]:
        """Sin preserves the shape of its operand."""
        return self.operand.check_shape()

    def __repr__(self):
        return f"(sin({self.operand!r}))"
83
+
84
+
85
class Cos(Expr):
    """Element-wise cosine function for symbolic expressions.

    Computes the cosine of each element in the operand. Preserves the shape
    of the input expression.

    Attributes:
        operand: Expression to apply cosine function to

    Example:
        Define a Cos expression:

            theta = Variable("theta", shape=(3,))
            cos_theta = Cos(theta)
    """

    def __init__(self, operand):
        """Initialize a cosine operation.

        Args:
            operand: Expression to apply cosine function to. Plain values are
                passed through ``to_expr`` so they are accepted like in the
                other element-wise ops of this module.
        """
        # children()/canonicalize()/check_shape() call Expr methods on the
        # operand, so raw numbers must be converted up front.
        self.operand = to_expr(operand)

    def children(self):
        """Return the single child expression."""
        return [self.operand]

    def canonicalize(self) -> "Expr":
        """Return a Cos over the canonicalized operand."""
        operand = self.operand.canonicalize()
        return Cos(operand)

    def check_shape(self) -> Tuple[int, ...]:
        """Cos preserves the shape of its operand."""
        return self.operand.check_shape()

    def __repr__(self):
        return f"(cos({self.operand!r}))"
122
+
123
+
124
class Tan(Expr):
    """Element-wise tangent function for symbolic expressions.

    Computes the tangent of each element in the operand. Preserves the shape
    of the input expression.

    Attributes:
        operand: Expression to apply tangent function to

    Example:
        Define a Tan expression:

            theta = Variable("theta", shape=(3,))
            tan_theta = Tan(theta)

    Note:
        Tan is only supported for JAX lowering. CVXPy lowering will raise
        NotImplementedError since tangent is not DCP-compliant.
    """

    def __init__(self, operand):
        """Initialize a tangent operation.

        Args:
            operand: Expression to apply tangent function to. Plain values are
                passed through ``to_expr`` so they are accepted like in the
                other element-wise ops of this module.
        """
        # children()/canonicalize()/check_shape() call Expr methods on the
        # operand, so raw numbers must be converted up front.
        self.operand = to_expr(operand)

    def children(self):
        """Return the single child expression."""
        return [self.operand]

    def canonicalize(self) -> "Expr":
        """Return a Tan over the canonicalized operand."""
        operand = self.operand.canonicalize()
        return Tan(operand)

    def check_shape(self) -> Tuple[int, ...]:
        """Tan preserves the shape of its operand."""
        return self.operand.check_shape()

    def __repr__(self):
        return f"(tan({self.operand!r}))"
165
+
166
+
167
class Square(Expr):
    """Element-wise square (x^2) of a symbolic expression.

    Every element of the wrapped expression is squared; the result has the
    same shape as the input. This is more efficient than using Power(x, 2)
    for some optimization backends.

    Attributes:
        x: The expression being squared

    Example:
        Build a squared expression:

            v = Variable("v", shape=(3,))
            v_squared = Square(v)  # Equivalent to v ** 2
    """

    def __init__(self, x):
        """Create a square node.

        Args:
            x: Expression (or value accepted by ``to_expr``) to square
        """
        self.x = to_expr(x)

    def children(self):
        """Return the single child expression."""
        return [self.x]

    def canonicalize(self) -> "Expr":
        """Return a new Square wrapping the canonical form of the operand."""
        return Square(self.x.canonicalize())

    def check_shape(self) -> Tuple[int, ...]:
        """Squaring is element-wise, so the shape of x is returned unchanged."""
        return self.x.check_shape()

    def __repr__(self):
        return f"({self.x!r})^2"
205
+
206
+
207
class Sqrt(Expr):
    """Element-wise square root of a symbolic expression.

    Takes the square root of every element of the operand; the output shape
    equals the input shape.

    Attributes:
        operand: The expression whose square root is taken

    Example:
        Build a square-root expression:

            x = Variable("x", shape=(3,))
            sqrt_x = Sqrt(x)
    """

    def __init__(self, operand):
        """Create a square-root node.

        Args:
            operand: Expression (or value accepted by ``to_expr``) to take
                the square root of
        """
        self.operand = to_expr(operand)

    def children(self):
        """Return the single child expression."""
        return [self.operand]

    def canonicalize(self) -> "Expr":
        """Return a new Sqrt wrapping the canonical form of the operand."""
        return Sqrt(self.operand.canonicalize())

    def check_shape(self) -> Tuple[int, ...]:
        """Square root is element-wise, so the operand's shape is kept."""
        return self.operand.check_shape()

    def __repr__(self):
        return f"sqrt({self.operand!r})"
244
+
245
+
246
class Exp(Expr):
    """Element-wise natural exponential of a symbolic expression.

    Evaluates e^x for every element of the operand (e being Euler's number);
    the output shape equals the input shape.

    Attributes:
        operand: The expression that is exponentiated

    Example:
        Build an exponential expression:

            x = Variable("x", shape=(3,))
            exp_x = Exp(x)
    """

    def __init__(self, operand):
        """Create an exponential node.

        Args:
            operand: Expression (or value accepted by ``to_expr``) to
                exponentiate
        """
        self.operand = to_expr(operand)

    def children(self):
        """Return the single child expression."""
        return [self.operand]

    def canonicalize(self) -> "Expr":
        """Return a new Exp wrapping the canonical form of the operand."""
        return Exp(self.operand.canonicalize())

    def check_shape(self) -> Tuple[int, ...]:
        """Exponentiation is element-wise, so the operand's shape is kept."""
        return self.operand.check_shape()

    def __repr__(self):
        return f"exp({self.operand!r})"
283
+
284
+
285
class Log(Expr):
    """Element-wise natural logarithm of a symbolic expression.

    Takes the base-e logarithm of every element of the operand; the output
    shape equals the input shape.

    Attributes:
        operand: The expression whose logarithm is taken

    Example:
        Build a logarithm expression:

            x = Variable("x", shape=(3,))
            log_x = Log(x)
    """

    def __init__(self, operand):
        """Create a natural-logarithm node.

        Args:
            operand: Expression (or value accepted by ``to_expr``) to take
                the logarithm of
        """
        self.operand = to_expr(operand)

    def children(self):
        """Return the single child expression."""
        return [self.operand]

    def canonicalize(self) -> "Expr":
        """Return a new Log wrapping the canonical form of the operand."""
        return Log(self.operand.canonicalize())

    def check_shape(self) -> Tuple[int, ...]:
        """The logarithm is element-wise, so the operand's shape is kept."""
        return self.operand.check_shape()

    def __repr__(self):
        return f"log({self.operand!r})"
322
+
323
+
324
class Abs(Expr):
    """Element-wise absolute value (|x|) of a symbolic expression.

    Replaces every element of the operand by its magnitude; the output shape
    equals the input shape. The absolute value is convex and DCP-compliant
    in CVXPy.

    Attributes:
        operand: The expression whose absolute value is taken

    Example:
        Build an absolute-value expression:

            x = Variable("x", shape=(3,))
            abs_x = Abs(x)  # Element-wise |x|
    """

    def __init__(self, operand):
        """Create an absolute-value node.

        Args:
            operand: Expression (or value accepted by ``to_expr``) to take
                the absolute value of
        """
        self.operand = to_expr(operand)

    def children(self):
        """Return the single child expression."""
        return [self.operand]

    def canonicalize(self) -> "Expr":
        """Return a new Abs wrapping the canonical form of the operand."""
        return Abs(self.operand.canonicalize())

    def check_shape(self) -> Tuple[int, ...]:
        """|x| is element-wise, so the operand's shape is kept."""
        return self.operand.check_shape()

    def __repr__(self):
        return f"abs({self.operand!r})"
362
+
363
+
364
class Max(Expr):
    """Element-wise maximum function for symbolic expressions.

    Computes the element-wise maximum across two or more operands. Supports
    broadcasting following NumPy rules. During canonicalization, nested Max
    operations are flattened and all constant operands are folded into a
    single constant.

    Attributes:
        operands: List of expressions to compute maximum over

    Example:
        Define a Max expression:

            x = Variable("x", shape=(3,))
            y = Variable("y", shape=(3,))
            max_xy = Max(x, y, 0)  # Element-wise max(x, y, 0)
    """

    def __init__(self, *args):
        """Initialize a maximum operation.

        Args:
            *args: Two or more expressions to compute maximum over

        Raises:
            ValueError: If fewer than two operands are provided
        """
        if len(args) < 2:
            raise ValueError("Max requires two or more operands")
        self.operands = [to_expr(a) for a in args]

    def children(self):
        """Return a copy of the operand list."""
        return list(self.operands)

    def canonicalize(self) -> "Expr":
        """Canonicalize max: flatten nested Max, fold constants.

        Each operand is canonicalized. Operands of nested ``Max`` nodes are
        spliced into this node. Every ``Constant`` encountered, including
        ones inside a nested ``Max``, is folded with a pairwise, broadcasting
        ``np.maximum`` into one constant, which is appended last.

        Returns:
            The sole remaining operand if only one is left, otherwise a new
            ``Max`` over the flattened operands.

        Raises:
            ValueError: If no operands remain after canonicalization.
        """
        from functools import reduce

        from .expr import Constant

        operands = []
        const_vals = []

        def _collect(expr):
            # Constants go to the fold list; everything else is kept as-is.
            if isinstance(expr, Constant):
                const_vals.append(expr.value)
            else:
                operands.append(expr)

        for op in self.operands:
            c = op.canonicalize()
            if isinstance(c, Max):
                # c is already canonical, but its operands may contain its own
                # folded Constant; re-classify so all constants fold together.
                for inner in c.operands:
                    _collect(inner)
            else:
                _collect(c)

        if const_vals:
            # Pairwise np.maximum broadcasts like check_shape(). By contrast,
            # np.maximum.reduce(list) must first build one array from the list,
            # which fails for mixed-shape constants (e.g. a scalar and a vector).
            max_const = reduce(np.maximum, const_vals)
            operands.append(Constant(max_const))

        if not operands:
            raise ValueError("Max must have at least one operand after canonicalization")
        if len(operands) == 1:
            return operands[0]
        return Max(*operands)

    def check_shape(self) -> Tuple[int, ...]:
        """Max broadcasts shapes like NumPy."""
        shapes = [child.check_shape() for child in self.children()]
        try:
            return np.broadcast_shapes(*shapes)
        except ValueError as e:
            raise ValueError(f"Max shapes not broadcastable: {shapes}") from e

    def __repr__(self):
        inner = ", ".join(repr(op) for op in self.operands)
        return f"max({inner})"
436
+
437
+
438
# Penalty function building blocks
class PositivePart(Expr):
    """Positive part, max(x, 0), of a symbolic expression.

    Negative elements are mapped to zero and non-negative elements are kept,
    i.e. the ReLU (Rectified Linear Unit). It is a common building block for
    penalty functions in optimization.

    Attributes:
        x: The expression whose positive part is taken

    Example:
        Build a positive-part penalty:

            constraint_violation = x - 10
            penalty = PositivePart(constraint_violation)  # Penalizes x > 10
    """

    def __init__(self, x):
        """Create a positive-part node.

        Args:
            x: Expression (or value accepted by ``to_expr``) to clip at zero
        """
        self.x = to_expr(x)

    def children(self):
        """Return the single child expression."""
        return [self.x]

    def canonicalize(self) -> "Expr":
        """Return a new PositivePart wrapping the canonical operand."""
        return PositivePart(self.x.canonicalize())

    def check_shape(self) -> Tuple[int, ...]:
        """max(x, 0) is element-wise, so the shape of x is kept."""
        return self.x.check_shape()

    def __repr__(self):
        return f"pos({self.x!r})"
478
+
479
+
480
class Huber(Expr):
    """Huber penalty of a symbolic expression.

    A smooth stand-in for the absolute value. It is quadratic near zero
    (|x| < delta) and linear further out (|x| >= delta). This makes it less
    sensitive to outliers than a squared penalty while remaining smooth.

    The Huber function is defined as:
        - (x^2) / (2*delta)   for |x| <= delta
        - |x| - delta/2       for |x| > delta

    Attributes:
        x: The expression the penalty is applied to
        delta: Transition threshold between the quadratic and linear regimes
            (default: 0.25)

    Example:
        Build a Huber penalty:

            residual = y_measured - y_predicted
            penalty = Huber(residual, delta=0.5)
    """

    def __init__(self, x, delta: float = 0.25):
        """Create a Huber penalty node.

        Args:
            x: Expression (or value accepted by ``to_expr``) to penalize
            delta: Quadratic-to-linear transition threshold; stored as a
                float (default: 0.25)
        """
        self.x = to_expr(x)
        self.delta = float(delta)

    def children(self):
        """Return the single child expression (delta is not a child)."""
        return [self.x]

    def canonicalize(self) -> "Expr":
        """Canonicalize the operand; delta is carried over unchanged."""
        return Huber(self.x.canonicalize(), delta=self.delta)

    def check_shape(self) -> Tuple[int, ...]:
        """The penalty is element-wise, so the shape of x is kept."""
        return self.x.check_shape()

    def _hash_into(self, hasher: "hashlib._Hash") -> None:
        """Feed the node tag, delta and operand into ``hasher``.

        Args:
            hasher: A hashlib hash object to update
        """
        # Tag first, then delta as a big-endian IEEE-754 double, so nodes
        # that differ only in delta hash differently.
        for chunk in (b"Huber", struct.pack(">d", self.delta)):
            hasher.update(chunk)
        self.x._hash_into(hasher)

    def __repr__(self):
        return f"huber({self.x!r}, delta={self.delta})"
539
+
540
+
541
class SmoothReLU(Expr):
    """Smooth approximation to the ReLU (positive part) function.

    Represents a differentiable surrogate for max(x, 0) given by:
        sqrt(max(x, 0)^2 + c^2) - c

    The parameter c sets the smoothness. Smaller values give a sharper
    corner and larger values a gentler one. As c approaches 0, the
    expression approaches the standard ReLU.

    This is particularly useful in optimization contexts where smooth
    gradients are required, such as in penalty methods for constraint
    handling (CTCS).

    Attributes:
        x: The expression the smooth ReLU is applied to
        c: Smoothing parameter (default: 1e-8)

    Example:
        Build a smooth ReLU penalty:

            constraint_violation = x - 10
            penalty = SmoothReLU(constraint_violation, c=1e-6)
    """

    def __init__(self, x, c: float = 1e-8):
        """Create a smooth ReLU node.

        Args:
            x: Expression (or value accepted by ``to_expr``) to apply the
                smooth ReLU to
            c: Smoothing parameter; stored as a float (default: 1e-8)
        """
        self.x = to_expr(x)
        self.c = float(c)

    def children(self):
        """Return the single child expression (c is not a child)."""
        return [self.x]

    def canonicalize(self) -> "Expr":
        """Canonicalize the operand; c is carried over unchanged."""
        return SmoothReLU(self.x.canonicalize(), c=self.c)

    def check_shape(self) -> Tuple[int, ...]:
        """The operation is element-wise, so the shape of x is kept."""
        return self.x.check_shape()

    def _hash_into(self, hasher: "hashlib._Hash") -> None:
        """Feed the node tag, c and operand into ``hasher``.

        Args:
            hasher: A hashlib hash object to update
        """
        # Tag first, then c as a big-endian IEEE-754 double, so nodes that
        # differ only in c hash differently.
        for chunk in (b"SmoothReLU", struct.pack(">d", self.c)):
            hasher.update(chunk)
        self.x._hash_into(hasher)

    def __repr__(self):
        return f"smooth_relu({self.x!r}, c={self.c})"
601
+
602
+
603
class LogSumExp(Expr):
    """Log-sum-exp function for symbolic expressions.

    Computes the log-sum-exp (LSE) of multiple operands element-wise (with
    NumPy broadcasting). It is a smooth, differentiable approximation to the
    maximum function, defined as:

        logsumexp(x₁, x₂, ..., xₙ) = log(exp(x₁) + exp(x₂) + ... + exp(xₙ))

    It is commonly used in optimization as a smooth alternative to the
    non-differentiable maximum function. It satisfies the inequality:

        max(x₁, x₂, ..., xₙ) ≤ logsumexp(x₁, x₂, ..., xₙ) ≤ max(x₁, x₂, ..., xₙ) + log(n)

    The log-sum-exp is convex and is particularly useful for:
        - Smooth approximations of maximum constraints
        - Soft maximum operations in neural networks
        - Relaxing logical OR operations in STL specifications

    Attributes:
        operands: List of expressions to compute log-sum-exp over

    Example:
        Define a LogSumExp expression:

            x = Variable("x", shape=(3,))
            y = Variable("y", shape=(3,))
            z = Variable("z", shape=(3,))
            lse = LogSumExp(x, y, z)  # Smooth approximation to max(x, y, z)

        Use in STL relaxation:

            import openscvx as ox
            # Relax: Or(φ₁, φ₂) using log-sum-exp
            phi1 = ox.Norm(x - goal1) - 0.5
            phi2 = ox.Norm(x - goal2) - 0.5
            relaxed_or = LogSumExp(phi1, phi2) >= 0
    """

    def __init__(self, *args):
        """Initialize a log-sum-exp operation.

        Args:
            *args: Two or more expressions to compute log-sum-exp over

        Raises:
            ValueError: If fewer than two operands are provided
        """
        if len(args) < 2:
            raise ValueError("LogSumExp requires two or more operands")
        self.operands = [to_expr(a) for a in args]

    def children(self):
        """Return a copy of the operand list."""
        return list(self.operands)

    def canonicalize(self) -> "Expr":
        """Canonicalize log-sum-exp: flatten nested LogSumExp, fold constants.

        Flattening is exact because
        ``logsumexp(logsumexp(a, b), c) == logsumexp(a, b, c)``. Every
        ``Constant`` encountered, including ones inside a nested
        ``LogSumExp``, is folded into a single element-wise constant with
        ``np.logaddexp``. That constant is appended last.

        Returns:
            The sole remaining operand if only one is left, otherwise a new
            ``LogSumExp`` over the flattened operands.

        Raises:
            ValueError: If no operands remain after canonicalization.
        """
        from functools import reduce

        from .expr import Constant

        operands = []
        const_vals = []

        def _collect(expr):
            # Constants go to the fold list; everything else is kept as-is.
            if isinstance(expr, Constant):
                const_vals.append(expr.value)
            else:
                operands.append(expr)

        for op in self.operands:
            c = op.canonicalize()
            if isinstance(c, LogSumExp):
                # c is already canonical, but its operands may contain its own
                # folded Constant; re-classify so all constants fold together.
                for inner in c.operands:
                    _collect(inner)
            else:
                _collect(c)

        if const_vals:
            # Pairwise np.logaddexp computes log(exp(a) + exp(b)) element-wise
            # with broadcasting and without overflowing in exp(). Summing a
            # list of exp'd arrays with np.sum would instead collapse array
            # constants to a scalar, breaking the element-wise semantics that
            # check_shape() reports.
            lse_const = reduce(np.logaddexp, const_vals)
            operands.append(Constant(lse_const))

        if not operands:
            raise ValueError("LogSumExp must have at least one operand after canonicalization")
        if len(operands) == 1:
            return operands[0]
        return LogSumExp(*operands)

    def check_shape(self) -> Tuple[int, ...]:
        """LogSumExp broadcasts shapes like NumPy, preserving element-wise shape."""
        shapes = [child.check_shape() for child in self.children()]
        try:
            return np.broadcast_shapes(*shapes)
        except ValueError as e:
            raise ValueError(f"LogSumExp shapes not broadcastable: {shapes}") from e

    def __repr__(self):
        inner = ", ".join(repr(op) for op in self.operands)
        return f"logsumexp({inner})"